用强化学习神包trl轻松实现GPT2可控文本生成

news2025/7/30 5:43:08

来源:投稿 作者:Sally can wait
编辑:学姐

模型github: lvwerra/trl: Train transformer language models with reinforcement learning. (github.com)https://github.com/lvwerra/trl

这个项目是复现 ”Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al一文的paper, code,因为觉得它非常好用,所以跟着跑通这个项目,并加上自己的理解介绍给大家。

理论基础

什么是可控文本生成?

虽然GPT2已经能生成流畅的句子,但是在特定话题的控制和逻辑性上仍然和期望相去甚远。我们希望一个文本生成模型可以一贯地围绕一个话题进行续写,而不是漫无目的地续写下去,这就是可控文本生成的研究目标。

在特定的运用场景中,我们有时需要用文本生成的方式增广数据。这时候可控文本生成可以在保证标签不变的前提下产生出大量的“伪数据”。

而大模型如GPT3、chatGPT效果较好,但是并不开源,而且由于巨大的参数量,微调起来也是浩大的工程。所以大部分的可控文本生成研究还是围绕GPT2做文章。

强化学习和PPO

强化学习不同于监督学习。监督学习只是对给定的、封闭的训练-验证数据集做参数优化,再用优化后的参数指导模型做出正确的输出。而强化学习的特点表现在强化信号上,强化信号是对产生动作的好坏作一种评价 (通常为标量),因此模型在不断产出输出的同时也在不断获得针对该输出的反馈,用这个反馈来更新模型参数。只要反馈机制是合理的,那么强化学习就可以一直进行下去,而不会面临训练数据匮乏的问题。

PPO(近端策略优化,Proximal Policy Optimisation)是强化学习目前最有效的一种算法。和先前的强化学习算法相比,PPO它在每一步迭代中都会尝试计算新的策略,这样可以让损失函数最小化,同时还能保证与上一步迭代的策略间的偏差相对较小。

PPO 里面有两项:一项是优化的奖励,另一项是一个约束。约束是为了防止模型被微调得过于离谱,失去了原有的语言模型做流畅的文字生成的能力。

How it works?

用PPO算法优化GPT2大致分以下三个步骤:

  1. 续写:GPT2先根据当前权重,续写给出的句子。

  2. 评估:GPT2续写的结果会经过一个分类层,或者也可以采用人工的打分,重要的是最终产生出一个数值型的分数。

  3. 优化:上一步对生成句子的打分会用于更新序列中token的对数概率。除此之外,还需要引入一个新的奖惩机制:KL散度。这需要用一个参考模型(通常是微调前的预训练模型,如GPT2-base)计算微调模型的输出和参考模型的输出之间的KL散度,把它和之前步骤的打分加在一起作为奖励函数,目的是确保生成的句子不会过多地偏离参考语言模型。然后使用PPO算法进一步训练语言模型。

实战:强化学习让GPT2产生正向IMDB影评

我们用强化学习的方法微调英文版 GPT2(small),让它基于 IMDB 数据集生成正面电影评论。该模型先是读取数据集中真实的影评,用GPT2续写。为了奖励情感倾向为正的续写,我们使用BERT影评分类器来分析生成的句子的情绪,把分类器的输出作为PPO训练的奖励。如果GPT2的续写经过分类器判别为正向情感,那么直接将分类器在正向情感上的置信度作为奖励加到ppo_trainer里面;反之,如果GPT2的续写经过分类器判别为负面情感,那么它在分类器输出层,正向情感对应的置信度会是负数或者很低,同样地将这个置信度加入ppo_trainer,可以提示模型减少对此输出的学习。

1.安装依赖包

conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge --yes

2.读取包

#torch==1.8, transformers==4.15.0
import torch
import wandb
import time
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
tqdm.pandas()
from datasets import load_dataset
from transformers import AutoTokenizer, pipeline

from trl.gpt2 import GPT2HeadWithValueModel, respond_to_batch
from trl.ppo import PPOTrainer
from trl.core import build_bert_batch_from_txt, listify_batch

3.设置需要用到的超参数

config = {
    "model_name": "lvwerra/gpt2-imdb",
    "cls_model_name": "lvwerra/distilbert-imdb",
    "steps": 20000,
    "batch_size": 256,
    "forward_batch_size": 16,
    "ppo_epochs": 4,   
    "txt_in_min_len": 2,
    "txt_in_max_len": 8,
    "txt_out_min_len": 4,
    "txt_out_max_len": 16,
    "lr": 1.41e-5,
    "init_kl_coef":0.2,
    "target": 6,
    "horizon":10000,
    "gamma":1,
    "lam":0.95,
    "cliprange": .2,
    "cliprange_value":.2,
    "vf_coef":.1, 
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe_device = 0 if torch.cuda.is_available() else -1

4.用wandb仪表盘监控训练过程中的各项指标、中间变量。首次使用需要注册一下。

wandb.init(name='run-42', project='gpt2-test', config=config, )

图:在训练过程中可以观察到训练过程中的中间变量。“query”和“response”分别表示IMDB的原始句子prompt(经过随机截断)和GPT2生成的续写,“reward”表示经过情感分类器之后的正向情感分值,越大表示情感越积极

5.加载IMDB数据集

ds = load_dataset('imdb', split='train')
ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'})
ds = ds.filter(lambda x: len(x["review"])>200, batched=False)

6.load一个集成在transformers pipeline里的影评分类器(此处也可以替换成别的分类器,只要有打分就行)

sent_kwargs = {
    "return_all_scores": True,
    "function_to_apply": "none",
    "batch_size": config["forward_batch_size"]
}#指定分类器输出的格式

sentiment_pipe = pipeline("sentiment-analysis","lvwerra/distilbert-imdb", device=pipe_device)
text = 'this movie was really bad!!'
sentiment_pipe(text, **sent_kwargs)

#一条分类后的结果长这样:我们需要的是score
#[[{'label': 'NEGATIVE', 'score': 2.3350484371185303}, {'label': 'POSITIVE', 'score': -2.726576089859009}]]

这里注意必须要确保transformers版本是4.15.0,不同版本的Pipeline输出大有不同

7.加载预训练GPT2-small

gpt2_model = GPT2HeadWithValueModel.from_pretrained(config['model_name'])
gpt2_model_ref = GPT2HeadWithValueModel.from_pretrained(config['model_name'])

gpt2_tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

wandb.watch(gpt2_model, log='all') #观察模型

gpt2_model.to(device);
gpt2_model_ref.to(device);

#设置文本生成的参数
gen_kwargs = {
    "min_length":-1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": gpt2_tokenizer.eos_token_id
}

8.长度控制+tokenize

class LengthSampler:
    def __init__(self, min_value, max_value):
        self.values = list(range(min_value, max_value))
    def __call__(self):
        return np.random.choice(self.values)
    
input_size = LengthSampler(config["txt_in_min_len"], config["txt_in_max_len"])
output_size = LengthSampler(config["txt_out_min_len"], config["txt_out_max_len"])
# 在tokenize之前,随机截断输入数据作为待续写的prompt,也随机确定续写长度,防止输入输出的长度过于单一

def tokenize(sample):
    sample["tokens"] = gpt2_tokenizer.encode(sample["review"])[:input_size()]
    sample["query"] = gpt2_tokenizer.decode(sample["tokens"])
    return sample

ds = ds.map(tokenize, batched=False)

def collater(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

dataloader = torch.utils.data.DataLoader(ds, batch_size=config['batch_size'], collate_fn=collater)

9.正式训练

ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **config)

total_ppo_epochs = int(np.ceil(config["steps"]/config['batch_size']))

for epoch, batch in tqdm(zip(range(total_ppo_epochs), iter(dataloader))):
    logs, timing = dict(), dict()
    t0 = time.time()
    query_tensors = [torch.tensor(t).long().to(device) for t in batch["tokens"]]
    
    #### Get response from gpt2
    t = time.time()
    response_tensors = []
    for i in range(config['batch_size']):
        gen_len = output_size()
        response = gpt2_model.generate(query_tensors[i].unsqueeze(dim=0),
                                       max_new_tokens=gen_len, **gen_kwargs)
        response_tensors.append(response.squeeze()[-gen_len:])
    batch['response'] = [gpt2_tokenizer.decode(r.squeeze()) for r in response_tensors]
    timing['time/get_response'] = time.time()-t

    #### Compute sentiment score
    t = time.time()
    texts = [q + r for q,r in zip(batch['query'], batch['response'])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    #[[{'label': 'NEGATIVE', 'score': 0.27862095832824707}, {'label': 'POSITIVE', 'score': -0.5044471621513367}]]
    rewards = torch.tensor([output[1]["score"] for output in pipe_outputs]).to(device) #each output has negative score(output[0]) and positive score(output[1])
    #如果一个prompt目前是negative,它的positive score是-0.5,那么加到奖励里面,相当于让它少学这个
    timing['time/get_sentiment_preds'] = time.time()-t
    
    #### Run PPO step 
    t = time.time()
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    timing['time/optimization'] = time.time()-t
     
    #### Log everything
    timing['time/epoch'] = time.time()-t0
    table_rows = [list(r) for r in zip(batch['query'], batch['response'], rewards.cpu().tolist())]
    logs.update({'game_log': wandb.Table(columns=['query', 'response', 'reward'], rows=table_rows)})
    logs.update(timing)
    logs.update(stats)
    logs['env/reward_mean'] = torch.mean(rewards).cpu().numpy()
    logs['env/reward_std'] = torch.std(rewards).cpu().numpy()
    logs['env/reward_dist'] = rewards.cpu().numpy()
    wandb.log(logs)

在训练过程中观察仪表盘,发现reward是上升的,说明训练是有效的。

10.保存模型

gpt2_model.save_pretrained('gpt2-imdb-pos-v2', push_to_hub=True)
gpt2_tokenizer.save_pretrained('gpt2-imdb-pos-v2', push_to_hub=True)

11.模型评估

通过比较原始GPT产生的无限制的文本和微调后产生的受控的文本,我们发现微调过程明显地让模型产生出了正面情感倾向的影评。同样地,在合适的位置添加负号就可以重新训练出一个会产生负面情绪文本的GPT2。针对指定label产生伪数据,这在数据增强上具有很高的应用价值。

此外,本实验的奖励机制是情感倾向值,也可以把奖励机制换成任何你喜欢的评价指标,运用在其他话题的生成任务上,看看模型是否会按照这个方向来学习。

关注下方《学姐带你玩AI》🚀🚀🚀

回复“ACL

免费获取文本生成&机器学习顶会高分论文PDF

码字不易,欢迎大家点赞评论收藏!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/394551.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++vector 简单实现

一。概述 vector是我们经常用的一个容器,其本质是一个线性数组。通过对动态内存的管理,增删改查数据,达到方便使用的目的。 作为一个线性表,控制元素个数,容量,开始位置的指针分别是: start …

Hive---拉链表

拉链表 文章目录拉链表定义用途案例全量流程增量流程合并过程第一步第二步第三步案例二(含分区)创建外部表orders增量分区表历史记录表定义 拉链表是一种数据模型,主要是针对数据仓库设计中表存储数据的方式而定义的,顾名思义&am…

从零开始学GeoServer源码十一(如何处理多个文件解析器Multipart Resolver引起的冲突问题)

目录前言1.现象2.排查问题3.找到问题4.解决问题5.总结前言 本文起源于我们遇到的一个问题,本来 GeoServer 使用的好好的,但是有天突然发现,无法在 GeoServer 中上传样式的 sld 文件了,报错 “No Multipart-config for Servlet” …

java.lang.IllegalArgumentException: itemView may not be null

报错截图:场景介绍:在使用recycleView 自动递增数据,且自动滚动到最新行; 当数据达到273条 时出现ANR;项目中 全部的列表适配器使用的三方库:BaseRecyclerViewAdapterHelper (很早之前的项目&am…

《SQL基础》16. 锁

锁锁全局锁表级锁表锁元数据锁意向锁行级锁行锁间隙锁临键锁锁 锁是计算机协调多个进程或线程并发访问某一资源的机制。在数据库中,除传统的计算资源(CPU、RAM、I/O)的争用以外,数据也是一种供许多用户共享的资源。如何保证数据并…

uniapp在线升级关联云空间

升级中心 uni-upgrade-center - App: https://ext.dcloud.net.cn/plugin?id4542 App升级中心 uni-upgrade-center文档: https://uniapp.dcloud.net.cn/uniCloud/upgrade-center.html#uni-upgrade-center-app 升级中心 uni-upgrade-center - Admin&#…

Ka频段需要更多带宽?

随着全球连接需求的增长,许多卫星通信(satcom)系统日益采用Ka频段,对数据速率的要求也水涨船高。目前,高性能信号链已经能支持数千兆瞬时带宽,一个系统中可能有成百上千个收发器,超高吞吐量数据速率已经成为现实。 另…

JavaWeb—HTML

目录 1、B/S 软件的结构 2、前端的开发流程 3、网页的组成部分 4、HTML 简介 5、创建 HTML 文件 6、HTML 文件的书写规范 7、HTML 标签介绍 8、常用标签介绍 8.1、font 字体标签 8.2、特殊字符 8.3、标题标签 8.4、超链接 ( **** 重 点 ,必 …

如何实现jwt鉴权机制之详解

jwt鉴权一是什么headerpayloadSignature二、如何实现生成 token校验token三、优缺点优点:缺点:一是什么 JWT(JSON Web Token),本质就是一个字符串书写规范,如下图,作用是用来在用户和服务器之间…

Wannacrypt蠕虫老树开花?又见Wannacrypt

Wannacrypt蠕虫是一个在2017年就出现的远古毒株,其利用永恒之蓝漏洞降维打击用户服务器,而后进行扩散勒索,曾经一度风靡全球,可谓是闻者伤心,听着落泪,因为这玩意解密是不可能 解密的。 而2023年的今天&am…

MCM 箱模型建模方法及大气 O3 来源解析实用干货

OBM 箱模型可用于模拟光化学污染的发生、演变过程,研究臭氧的生成机制和进行敏感性分析,探讨前体物的排放对光化学污染的影响。箱模型通常由化学机理、物理过程、初始条件、输入和输出模块构成,化学机理是其核心部分。MCM (Master Chemical M…

【每天学习一点新知识】JNDI注入

什么是JNDIJNDI是Java的一种API,为我们提供了查找和访问各种命名和目录服务的通用统一的接口。通过JNDI统一接口我们可以来访问各种不同类型的服务,例如远程方法调用(RMI),通用对象请求代理体系结构(CORBA&…

Qt QTreeView简单使用

QT-QTreeView使用方法 QTreeView: 用于显示树状结构数据,适用于树状结构数据的操作。 一、初始化 ​ 利用QStandardlternModel来初始化数据,标准的基于项数据的数据模型类, 每个项数据可以是任何数据类型。 // 初始化model QStandardItem…

工作实战之拦截器模式

目录 前言 一、结构中包含的角色 二、拦截器使用 1.拦截器角色 a.自定义拦截器UserValidateInterceptor,UserUpdateInterceptor,UserEditNameInterceptor b.拦截器配置者UserInterceptorChainConfigure,任意组装拦截器顺序 c.拦截器管理者…

VM安装FydeOS16.0

准备工作 1、已安装VMware Workstation虚拟机; 2、下载好系统文件; 3、打开VM、新建虚拟机; 一、下载 https://fydeos.com/download/vm 我选择的镜像1。等下载完成,我这感觉下载速度不快,通过onedrive下载要快。 …

Jfrog 搭建本地maven仓库以及上传Android库

Jfrog 下载 安装包下载地址:Download Artifactory OSS | JFrog 如果是想下载之前的版本,可以点击上面的Get code source ,如果是最新版本,直接点下面的下载就好。下面以Linux安装为例。 Jfrog安装 对于Linux而言,其实…

Java泛型深入

一. 泛型的概述和优势 泛型概述 泛型&#xff1a;是JDK5中引入的特性&#xff0c;可以在编译阶段约束操作的数据类型&#xff0c;并进行检查。泛型的格式&#xff1a;<数据类型>&#xff0c;注意&#xff1a;泛型只能支持引用数据类型。集合体系的全部接口和实现类都是…

Java刷题-----蓝桥杯省赛JavaC组第十二届(第二场)4-------------6

4、格点题目本题总分&#xff1a;10 分问题描述如果一个点 ( x , y ) 的两维坐标都是整数&#xff0c;即 x ∈ Z 且 y ∈ Z &#xff0c;则称这个点为一个格点。如果一个点 ( x , y ) 的两维坐标都是正数&#xff0c;即 x > 0 且 y > 0 &#xff0c;则称这个点在第一象限…

浅谈 Nodejs原型链污染

一直在做php的题目&#xff0c;对其它语言做的很少。刚好在西湖论剑2022复现时&#xff0c;遇到了一道原型链污染的题目&#xff0c;借此机会开始简单学习一下 Nodejs的洞 p&#x1f402;讲解的十分清楚&#xff0c;因此下面举例子就直接用p&#x1f402;的例子进行解释了 目…

SNMP学习和测试

学习 &#xff08;1&#xff09;SNMP是简单网络管理协议&#xff0c;但是多么晦涩我就不说了。 &#xff08;2&#xff09;SNMP工作在应用层&#xff0c;也就是通过socket实现的&#xff0c;基于UDP&#xff0c;端口161和162&#xff0c;161是用于和网管通信的端口&#xff0…