GPT - 2 文本生成任务全流程

news2025/4/20 7:07:23

数据集下载

数据预处理 

import json
import pandas as pd

all_data = []

with open("part-00018.jsonl",encoding="utf-8") as f:
    for line in f.readlines():
        data = json.loads(line)
        all_data.append(data["text"])

batch_size = 10000

for i in range(0,len(all_data),batch_size):
    begin = i
    end = i + batch_size

    df = pd.DataFrame({"content":all_data[begin:end]})
    df.to_csv(f"./data/{i}.csv",index=False)

GPT-2 模型的配置 

这部分代码的功能是初始化一个 GPT-2 模型的配置对象 GPT2Config,该对象将用于后续创建 GPT-2 模型实例。

方式一:在线配置

config = GPT2Config.from_pretrained("openai-community/gpt2",
                                    vocab_size=len(tokenizer),
                                    n_ctx=context_length,
                                    bos_token_id = tokenizer.bos_token_id,
                                    eos_token_id = tokenizer.eos_token_id,
                                    )

方式二:复制官网配置文件到本地

创建本地文件夹

复制官网配置文件到本地https://huggingface.co/openai-community/gpt2/blob/main/config.json

 

{
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "vocab_size": 50257
}
config = GPT2Config.from_pretrained("config/gpt2.config",
                                    vocab_size=len(tokenizer),
                                    n_ctx=context_length,
                                    bos_token_id = tokenizer.bos_token_id,
                                    eos_token_id = tokenizer.eos_token_id,
                                    )

模型映射、模型训练 

from glob import glob
import os
from torch.utils.data import Dataset
from datasets import load_dataset
import random
from transformers import BertTokenizerFast
from transformers import GPT2Config
from transformers import GPT2LMHeadModel
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer,TrainingArguments

def tokenize(element):
    outputs = tokenizer(element["content"],truncation=True,max_length=context_length,return_overflowing_tokens=True,return_length=True)

    input_batch = []

    for length,input_ids in zip(outputs["length"],outputs["input_ids"]):

        if length == context_length:
            input_batch.append(input_ids)

    return {"input_ids":input_batch}

if __name__ == "__main__":
    random.seed(1002)
    test_rate = 0.2
    context_length = 128

    all_files = glob(pathname=os.path.join("data","*"))

    test_file_list = random.sample(all_files,int(len(all_files)*test_rate))
    train_file_list = [i for i in all_files if i not in test_file_list]

    raw_datasets = load_dataset("csv",data_files={"train":train_file_list,"vaild":test_file_list},cache_dir="cache_data")


    tokenizer = BertTokenizerFast.from_pretrained("D:/bert-base-chinese")
    tokenizer.add_special_tokens({"bos_token":"[begin]","eos_token":"[end]"})

    tokenize_datasets = raw_datasets.map(tokenize,batched=True,remove_columns=raw_datasets["train"].column_names)

    config = GPT2Config.from_pretrained("config/gpt2.config",
                                        vocab_size=len(tokenizer),
                                        n_ctx=context_length,
                                        bos_token_id = tokenizer.bos_token_id,
                                        eos_token_id = tokenizer.eos_token_id,
                                        )

    model = GPT2LMHeadModel(config)
    model_size = sum([ t.numel() for t in model.parameters()])
    print(f"model_size: {model_size/1000/1000} M")

    data_collator = DataCollatorForLanguageModeling(tokenizer,mlm=False)

    args = TrainingArguments(
        learning_rate=1e-5,
        num_train_epochs=100,
        per_device_train_batch_size=10,
        per_device_eval_batch_size=10,
        eval_steps=2000,
        logging_steps=2000,
        gradient_accumulation_steps=5,
        weight_decay=0.1,
        warmup_steps=1000,
        lr_scheduler_type="cosine",
        save_steps=100,
        output_dir="model_output",
        fp16=True,
    )

    trianer = Trainer(
        model=model,
        args=args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=tokenize_datasets["train"],
        eval_dataset=tokenize_datasets["vaild"]
    )

    trianer.train()

文本生成交互界面

from transformers import GPT2LMHeadModel,BertTokenizerFast
import os

tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")
model_path = os.path.join("model_output","checkpoint-100")

model = GPT2LMHeadModel.from_pretrained(model_path,pad_token_id=tokenizer.pad_token_id)
model = model.to("cuda")

while True:
    input_text = input("请输入:")
    input_ids = tokenizer.encode(input_text,return_tensors="pt")
    input_ids = input_ids.to("cuda")

    output = model.generate(input_ids,max_length=400,num_beams=5,repetition_penalty=1,early_stopping=True)

    output_text = tokenizer.decode(output[0],skip_special_tokens=True)
    print(f"输出:{output_text}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2335531.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

红宝书第四十三讲:基于资料的数据可视化工具简单介绍:D3.js 与 Canvas绘图

红宝书第四十三讲:基于资料的数据可视化工具简单介绍:D3.js 与 Canvas绘图12 资料取自《JavaScript高级程序设计(第5版)》。 查看总目录:红宝书学习大纲 一、D3.js:数据驱动文档的王者 1 核心特性&#x…

深入理解 Vue 的数据代理机制

何为数据代理? 通过一个对象代理对另一个对象中的属性的操作(读/写),就是数据代理。 要搞懂Vue数据代理这个概念,那我们就要从Object.defineProperty()入手 Object.defineProperty()是Vue中比较底层的一个方法&…

Java excel导入/导出导致内存溢出问题,以及解决方案

excel导入/导出导致内存溢出问题,以及解决方案 1、内存溢出问题导入功能重新修正,采用SAX的流式解析数据。并结合业务流程。导出功能:由于精细化了业务流程,导致比较代码比较冗杂,就只放出最简单的案例。 1、内存溢出问…

10 个最新 CSS 功能已在所有主流浏览器中得到支持

前言 CSS 不断发展,新功能使我们的工作更快、更简洁、更强大。得益于最新的浏览器改进(Baseline 2024),许多新功能现在可在所有主要引擎上使用。以下是您可以立即开始使用的10 CSS新功能。 1. Scrollbar-Gutter 和 Scrollbar-Co…

思科模拟器的单臂路由,交换机,路由器,路由器只要两个端口的话,连接三台电脑该怎么办,划分VLAN,dotlq协议

单臂路由 1. 需求:让三台电脑互通 2. 在二层交换机划分vlan,并加入; 3. 将连接二层交换机和路由器的端口f0/4改为trunk模式 4. 路由器:进入连接路由器的f0/0端口将端口开启 5. 进入每个vlan设dotlq协议并设网络IP&#xff08…

14 nginx 的 dns 缓存的流程

前言 这个是 2020年11月 记录的这个关于 nginx 的 dns 缓存的问题 docker 环境下面 前端A连到后端B 前端B连到后端A 最近从草稿箱发布这个问题的时候, 重新看了一下 发现该问题的记录中仅仅是 定位到了 nginx 这边的 dns 缓存的问题, 但是 并没有到细节, 没有到 具体的 n种…

实战教程:使用JetBrians Rider快速部署与调试PS5和Xbox上的UE项目

面向主机游戏开发者的重大新闻!在2024.3版本中,JetBrains Rider 增加了对 PlayStation5 和 Xbox 游戏主机的支持,您可以直接在您喜欢的游戏主机上构建、部署和调试 Unreal Engine 和自定义游戏引擎。 JetBrains Rider现在支持主机游戏开发&am…

专题十五:动态路由——BGP

一、BGP的基本概念 BGP(Border Gateway Protocol,边界网关协议)是一种用于在不同自治系统(AS)之间交换路由信息的外部网关协议(EGP)。通过TCP179端口建立连接。目前采用BGP4版本,IP…

hive数仓要点总结

1.OLTP和OLAP区别 OLTP(On-Line Transaction Processing)即联机事务处理,也称为面向交易的处理过程,其基本特征是前台接收的用户数据可以立即传送到计算中心进行处理,并在很短的时间内给出处理结果,是对用…

git安装(windows)

通过网盘分享的文件:资料(1) 链接: https://pan.baidu.com/s/1MAenYzcQ436MlKbIYQidoQ 提取码: evu6 点击next 可修改安装路径 默认就行 一般从命令行调用,所以不用创建。 用vscode,所以这么选择。

微信小程序实战案例 - 餐馆点餐系统 阶段1 - 菜单浏览

阶段 1 – 菜单浏览(超详细版) 目标:完成「首页=菜品卡片列表」 打好 UI 地基会从 云数据库 拉取 categories / dishes 并渲染打 Git Tag v1.0‑menu 1. 技术/知识点速览 知识点关键词说明云数据库db.collection().where().…

Dashboard的安装和基本使用

1.Dashboard简介: Dashboard是Kubernetes的Web图形用户界面(GUI),它为用户提供了一个直观的方式来管理和监控Kubernetes集群。 2.实验基础和前置条件: 本实验以Kubernetes集群环境搭建与初始化-CSDN博客为基础和前置…

英语单词 list 11

前言 这一个 list 是一些简单的单词。感觉这个浏览单词的方法比较低效,所以准备每天最多看一个 list ,真要提升英语水平,感觉还是得直接做阅读理解题。就像我们接触中文阅读材料一样,当然光知道这个表面意思还不够,还…

通义灵码助力Neo4J开发:快速上手与智能编码技巧

在 Web 应用开发中,Neo4J 作为一种图数据库,用于存储节点及节点间的关系。当图结构复杂化时,关系型数据库的查找效率会显著降低,甚至无法有效查找,这时 Neo4J 的优势便凸显出来。然而,由于其独特的应用场景…

高性能文件上传服务

高性能文件上传服务 —— 您业务升级的不二选择 在当今互联网数据量激增、文件体积日益庞大的背景下,高效、稳定的文件上传方案显得尤为重要。我们的文件分块上传服务端采用业界领先的 Rust HTTP 框架 Hyperlane 开发,凭借其轻量级、低延时和高并发的特…

Java Lambda 表达式详解:发展史、语法、使用场景及代码示例

Java Lambda 表达式详解:发展史、语法、使用场景及代码示例 1. Lambda 表达式的发展史 背景与动机 JDK 7 前:Java的匿名内部类虽强大,但代码冗余(如事件监听器、集合遍历)。JDK 8(2014)&#…

【从0到1学Elasticsearch】Elasticsearch从入门到精通(下)

我们在【从0到1学Elasticsearch】Elasticsearch从入门到精通(上)这边文章详细讲解了如何创建索引库和文档及javaAPI操作,但是在实战当中,我们还需要根据一些特殊字段对文档进行查找搜索,仅仅靠id查找文档是显然不够的。…

Python实现贪吃蛇二

上篇文章Python实现贪吃蛇一,实现了一个贪吃蛇的基础版本,但存在一些不足,也缺乏一些乐趣。本篇文章将对其进行一些改进,主要修改/实现以下几点: 1、解决食物随机生成的位置与蛇身重合问题 2、蛇身移动加速/减速功能 3…

基于51单片机的正负5V数字电压表( proteus仿真+程序+设计报告+讲解视频)

基于51单片机的正负5V数字电压表( proteus仿真程序设计报告讲解视频) 仿真图proteus7.8及以上 程序编译器:keil 4/keil 5 编程语言:C语言 设计编号:S0101 1. 主要功能: 设计一个基于51单片机数字电压表 1、能够…

Java雪花算法

以下是用Java实现的雪花算法代码示例,包含详细注释和异常处理: 代码下面有解析 public class SnowflakeIdGenerator {// 起始时间戳(2020-01-01 00:00:00)private static final long START_TIMESTAMP 1577836800000L;// 各部分…