人工智能LLM模型:奖励模型的训练、PPO 强化学习的训练、RLHF

news2025/9/13 22:49:52

人工智能LLM模型:奖励模型的训练、PPO 强化学习的训练

1.奖励模型的训练

1.1大语言模型中奖励模型的概念

在大语言模型完成 SFT 监督微调后,下一阶段是构建一个奖励模型来对问答对作出得分评价。奖励模型源于强化学习中的奖励函数,能对当前的状态刻画一个分数,来说明这个状态产生的价值有多少。在大语言模型微调中的奖励模型是对输入的问题和答案计算出一个分数。输入的答案与问题匹配度越高,则奖励模型输出的分数也越高。

1.2 奖励模型的模型架构与损失函数

1.2.1 模型架构

奖励模型(RM 模型)将 SFT 模型最后一层的 softmax 去掉,即最后一层不用 softmax,改成一个线性层。RM 模型的输入是问题和答案,输出是一个标量即分数。

由于模型太大不够稳定,损失值很难收敛且小模型成本较低,因此,RM 模型采用参数量为 6B 的模型,而不使用 175B 的模型。

1.2.2 损失函数

奖励模型的训练数据是人工对问题的每个答案进行排名,如下图所示:

对于每个问题,给出若干答案,然后工人进行排序,而奖励模型就是利用排序的结果来进行反向传播训练。奖励模型的损失函数采用 Pairwise Ranking Loss,公式如下所示:

l o s s ( θ ) = − ( K 2 ​ ) 1 ​ E ( x , y w ​ , y l ​ )   D ​ [ l o g ( σ ( r θ ​ ( x , y w ​ ) − r θ ​ ( x , y l ​ ) ) ) ] loss(θ)=−(K2​)1​E(x,yw​,yl​) D​[log(σ(rθ​(x,yw​)−rθ​(x,yl​)))] loss(θ)=(K2​)1​E(x,yw,yl) D[log(σ(rθ(x,yw)rθ(x,yl)))]

其中:
D:人工对答案进行排序的数据集;
x:数据集D中的问题;
K:每个问题对应的答案数量;
yw​yl​:问题x对应的K个答案中的两个,且yw​的排序比yl​高,由于是一对,也称 pairwiserθ​(x,y):需要训练的 RM 模型,对于输入的一对xy得到的标量分数;
θ:RM 模型需要优化的参数。

如何理解 RM 模型的损失函数呢?

RM 模型的目标是使得排序高的答案yw​对应的标量分数要高于排序低的答案yl​对应的标量分数,且越高越好,也就是使得损失函数中的rθ​(x,yw​)−rθ​(x,yl​)这个差值越大越好。将相减后的分数通过 sigmoid 函数,差值变成 - 1 到 1 之间,由于 sigmoid 函数是单调递增的函数,因此σ(rθ​(x,yw​)−rθ​(x,yl​))越大越好。σ(rθ​(x,yw​)−rθ​(x,yl​))约接近 1,表示yw​yl​排序高,属于 1 这个分类,反正属于 - 1 这个分类,所以这里也可以看成是一个二分类问题。再加上 logistic 函数,也就是相当于交叉熵损失函数。对于每个问题都有K个答案,在损失函数前除以CK2​,使得损失函数值不会因为K的变化而变化太多。损失函数的最终目标是最小化loss(θ),与最大化rθ​(x,yw​)−rθ​(x,yl​)相对应。

奖励模型中每个问题对应的答案数量即K值为什么选 9 更合适,而不是选择 4 呢?

  • 进行标注的时候,需要花很多时间去理解问题,但答案之间比较相近,假设 4 个答案进行排序要 30 秒时间,那么 9 个答案排序可能就 40 秒就够了。9 个答案与 4 个答案相比生成的问答对多了 5 倍,从效率上来看非常划算;
  • K=9时,每次计算 loss 都有 36 项rθ​(x,y)需要计算,RM 模型的计算所花时间较多,但可以通过重复利用之前算过的值(也就是只需要计算 9 次即可),能节约很多时间。

奖励模型的损失函数为什么会比较答案的排序,而不是去对每一个答案的具体分数做一个回归?

每个人对问题的答案评分都不一样,无法使用一个统一的数值对每个答案进行打分。如果采用对答案具体得分回归的方式来训练模型,会造成很大的误差。但是,每个人对答案的好坏排序是基本一致的。通过排序的方式避免了人为的误差。

1.3 总结

奖励模型通过与人类专家进行交互,获得对于生成响应质量的反馈信号,从而进一步提升大语言模型的生成能力和自然度。与监督模型不同的是,奖励模型通过打分的形式使得生成的文本更加自然逼真,让大语言模型的生成能力更进一步。

2.PPO 强化学习的训练

2.1 PPO 强化学习概念

大语言模型完成奖励模型的训练后,下一个阶段是训练强化学习模型(RL 模型),也是最后一个阶段。大语言模型微调中训练 RL 模型采用的优化算法是 PPO(Proximal Policy Optimization,近端策略优化)算法,即对设定的目标函数通过随机梯度下降进行优化。近端策略优化是一种深度强化学习算法,用于训练智能体在复杂环境中学习和执行任务。通过智能体的训练,使得其在与环境的交互中能够最大化累积回报,从而达成指定任务目标。这里的智能体在大语言模型中指的就是 RL 模型。

2.2 PPO 强化学习原理

RL 模型的初始模型采用 SFT 微调之后的大语言预训练模型。训练 RL 模型的数据集只需要收集问题集(Prompt 集),不需要对问题进行标注。问题集通过 RL 模型生成答案文本,然后将问题和答案输入上一步训练的 RW 模型进行打分,来评价生成的文本质量,而训练 RL 模型的目标是使得生成的文本要在 RW 模型上获得尽可能高的得分。

将初始语言模型的微调任务建模为强化学习(RL)问题,需要定义策略(policy)、动作空间(action space)和奖励函数(reward function)等基本要素。

策略就是基于该语言模型,接收 prompt 作为输入,然后输出一系列文本(或文本的概率分布);而动作空间就是词表所有 token 在所有输出位置的排列组合;观察空间则是可能的输入 token 序列(即 prompt),为词表所有 token 在所有输入位置的排列组合;而奖励函数则是上一阶段训好的 RM 模型,配合一些策略层面的约束进行的奖励计算。该阶段流程如下图所示:

RL 模型训练的损失函数公式如下:

o b j e c t i v e ( ϕ ) = E ( x , y ) ∼ D π ϕ R L ​​​ [ r θ ​ ( x , y ) − β l o g ( π ϕ R L ​ ( y ∣ x ) / π S F T ( y ∣ x ) ) ] + γ E x ∼ D p r e t r a i n ​​ [ l o g ( π ϕ R L ​ ( x ) ) ] objective(ϕ)=E(x,y)∼DπϕRL​​​[rθ​(x,y)−βlog(πϕRL​(y∣x)/πSFT(y∣x))]+γEx∼Dpretrain​​[log(πϕRL​(x))] objective(ϕ)=E(x,y)DπϕRL​​​[rθ(x,y)βlog(πϕRL(yx)/πSFT(yx))]+γExDpretrain​​[log(πϕRL(x))]

其中:
πSFT:SFT 模型;
πϕRL​:强化学习中,模型叫做 Policy,πϕRL​就是需要调整的模型,即最终模型。初始化是πSFT(x,y)∼DπϕRL​​x是 RL 数据集中的问题,yx通过πϕRL​模型得到的答案;
rθ​(x,y):对问题x和答案y进行打分的 RM 模型;
πϕRL​(y∣x):问题x通过πϕRL​得到答案y的概率,即对于每一个y的预测和它的 softmax 的输出相乘;
πSFT(y∣x):问题x通过πSFT得到答案y的概率;
x∼Dpretrain​x是来自大语言模型预训练阶段的数据;
βγ:调整系数。

RL 模型的优化目标是使得损失函数越大越好,损失函数可以分为三个部分,打分部分、KL 散度部分以及预训练部分。

  • **打分部分:**将 RL 模型的问题数据集x,通过πϕRL​模型得到答案y,然后再把这对(x,y)代入 RW 模型进行打分,即损失函数公式中的rθ​(x,y)。该分数越高,代表模型生成的答案越好。
  • **KL 散度部分:**在每次更新参数后,πϕRL​会发生变化,x通过πϕRL​生成的y也会发生变化,而rθ​(x,y)奖励模型是根据πSFT模型的数据训练而来。如果πϕRL​πSFT差的太多,则会导致rθ​(x,y)的分数估算不准确。因此需要通过 KL 散度来计算,πϕRL​生成的答案分布和πSFT生成的答案分布之间的距离,使得两个模型之间不要差的太远。损失函数公式中的log(πϕRL​(y∣x)/πSFT(y∣x))就是在计算 KL 散度。由于 KL 散度是越小越好,而训练目标是损失函数越大越好,因此在前面需要加上一个负号。
  • **预训练部分:**预训练部分对应损失函数中的Ex∼Dpretrain​​[log(πϕRL​(x))]。如果没有该项,那么模型最终可能只对这一个任务能够做好,在别的任务上会发生性能下降。因此,需要将预训练阶段的目标函数加上,使得前面两个部分在新的数据集上做拟合的同时保证原始的数据也不会丢弃。

最终优化后的πϕRL​模型就是大语言模型的最终模型。

2.3 总结

通过强化学习的训练方法,迭代式的更新奖励模型(RW 模型)以及策略模型(RL 模型),让奖励模型对模型输出质量的刻画愈加精确,策略模型的输出则愈能与初始模型拉开差距,使得输出文本变得越来越符合人的认知。这种训练方法也叫做 RLHF。

目前,RLHF 技术对训练大语言模型具有极大的影响力,训练出来的效果好于之前的方法。但是,RLHF 训练出来的大语言模型仍然可能输出有害或事实上不准确的文本,需要不断不断改进。此外,在基于 RLHF 范式训练模型时,人工标注的成本还是非常高昂的,RLHF 性能最终仅能达到标注人员的知识水平。这里的人工标注主要是为 RM 模型标注输出文本的排序结果,而若想要用人工去撰写答案的方式来训练模型,那成本更是不可想象。

3.关键知识点

  1. 大语言模型微调中的奖励模型训练:1.奖励模型输入问答对,输出得分 2.奖励模型的损失函数目的是使得得分较高的答案比得分较低的答案尽可能大,3.奖励模型是判别式模型

  2. 奖励模型是:监督学习、强化学习、判别式模型

  3. 大语言模型训练中的PPO强化学习:1.在大语言模型训练中,强化学习模型架构与SFT监督微调的模型一样,2.RLHF中训练强化学习模型阶段不需要标注问题的答案 3.RLHF中的初始策略就是SFT模型

  4. 关于RLHF方法中RL模型训练的损失函数:1.RL模型的损失函数包含三个部分 2.RL模型的损失函数需要计算策略更新后的RL模型与SFT模型输出的KL散度 3.RL模型的损失函数需要计算大语言模型预训练阶段的损失函数 4.RL模型的损失函数要使得RL模型生成的文本在奖励模型中的得分越高越好

  5. RLHF本质上是通过人类的反馈来优化模型,生成的文本会更加的自然。

更多优质内容请关注公号:汀丶人工智能;会提供一些相关的资源和优质文章,免费获取阅读。

型生成的文本在奖励模型中的得分越高越好

  1. RLHF本质上是通过人类的反馈来优化模型,生成的文本会更加的自然。

更多优质内容请关注公号:汀丶人工智能;会提供一些相关的资源和优质文章,免费获取阅读。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/764219.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

高通芯片android进入EDL模式 下载 热启动 串口指令

参考:高通方案的Android设备几种开机模式的进入与退出_edl模式怎么退出_Rookie20190715的博客-CSDN博客 切换为EDL模式 向串口发送 4b 65 01 00 54 0f 7e 或者adb reboot edl

Ceph的安装部署

文章目录 一、存储基础1.1 单机存储设备1.2 单机存储的问题1.3分布式存储(软件定义的存储 SDS) 二、Ceph 简介2.1 Ceph 优势2.2 Ceph 架构2.3 Ceph 核心组件2.4 Pool、PG 和 OSD 的关系:2.5 OSD 存储后端2.6 Ceph 数据的存储过程2.7 Ceph 版本…

PID控制系列--(1、最形象的PID)

目录 1、 比例控制系统的标准结构2、最简单的例子3、第二个例子4、积分控制器6、微分控制7 总结 今天 看到了B站上一个叫洋葱auto的UP主搬来的介绍PID控制的视频,感觉讲得形象易懂,为便于让和我一样看了无数文章还是不能很好理解PID控制本质的人共同分享…

2. DATASETS DATALOADERS

2. DATASETS & DATALOADERS PyTorch提供了两个数据基元:torch.utils.data.DataLoader和torch.uutils.data.data集,允许使用预加载的数据集以及自己的数据。数据集存储样本及其相应的标签,DataLoader在数据集周围包装了一个可迭代项&…

Sentinel整合OpenFegin

之前学习了openFeign的使用&#xff0c;我是超链接 现在学习通过Sentinel来进行整合OpenFegin。 引入OpenFegin 我们需要在当前的8084项目中引入对应的依赖 <dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-sta…

网络套接字编程(一)(UDP)

gitee仓库&#xff1a;https://gitee.com/WangZihao64/linux/tree/master/chat_udp 预备知识 源IP地址和目的IP地址 它是用来标识网络中不同主机的地址。两台主机进行通信时&#xff0c;发送方需要知道自己往哪一台主机发送&#xff0c;这就需要知道接受方主机的的IP地址&am…

【数学建模】利用C语言来实现 太阳赤纬 太阳高度角 太阳方位角 计算和求解分析 树木树冠阴影面积与种植间距的编程计算分析研究

太阳赤纬的计算 #include <stdio.h> #include <math.h>double calculateDelta(int year, int month, int day, int hour, int minute, int second) {int n, n0;double t, theta, delta;// 计算n和n0n month * 30 day;n0 79.6764 0.2422 * (year - 1985) - ((y…

35+大龄程序员从焦虑到收入飙升:我的搞钱副业分享。

37岁大龄程序员&#xff0c;一度觉得自己的职场生涯到头了。既没有晋升和加薪的机会&#xff0c;外面的公司要么接不住我的薪资&#xff0c;要么就是卷得不行&#xff0c;无法兼顾工作和家庭&#xff0c;感觉陷入了死局…… 好在我又重新振作起来&#xff0c;决定用副业和兼职填…

2.3Listbox列表部件

2.3Listbox列表部件 创建主窗口 window tk.Tk() window.title(my window) window.geometry(200x200)创建一个label用于显示 var1 tk.StringVar() #创建变量 l tk.Label(window,bgyellow,width4,textvariablevar1) l.pack()创建一个方法用于按钮的点击事件 def print_s…

DateTimePicker基本用法

作用&#xff1a;日期时间控件&#xff0c;用于手动选择日期与时间。 常用属性&#xff1a; 常用事件&#xff1a; 后台代码示范&#xff1a; //日期变化时获取日期private void dateTimePicker1_ValueChanged(object sender, EventArgs e){textBox2.Text dateTimePicker1.Te…

(原创)适合小白的AI算法学习路线

大家好啊&#xff0c;我是董董灿。 之前写了一篇文章&#xff1a;有前途&#xff01;大模型也需要AI算子开发岗&#xff01;有同学看了之后&#xff0c;在问AI算子开发需要如何学习&#xff0c;有没有学习路线? 当然是有的了。 今天周末在家&#xff0c;就梳理了一下该岗位需…

pdf水印在哪里设置?超实用解决方法分享

在工作中&#xff0c;我们常常需要发送PDF文件给他人&#xff0c;为了保护文件的安全性&#xff0c;防止被他人盗用或篡改&#xff0c;我们通常会给PDF文件添加水印。添加水印可以有效地标识文件的所有权&#xff0c;并增加文件的可追溯性。然而&#xff0c;有许多人不清楚如何…

2023牛客多校第一场B Anticomplementary Triangle

①&#xff1a;有结论&#xff1a;面积最大的三角形即为所求 证明&#xff1a;若有点在面积最大的三角形对应 “ A n t i c o m p l e m e n t a r y T r i a n g l e ” “Anticomplementary Triangle” “AnticomplementaryTriangle”之外&#xff0c;一定能取得更大的面积。…

14款奔驰R400升级ACC自适应巡航系统,解放您双脚

有的时候你是否厌倦了不停的刹车、加油&#xff1f;是不是讨厌急刹车&#xff0c;为掌握不好车距而烦恼&#xff1f;如果是这样&#xff0c;那么就升级奔驰原厂ACC自适应式巡航控制系统&#xff0c;带排队自动辅助和行车距离警报功能&#xff0c;感受现代科技带给你的舒适安全和…

实时监测与报警,探索CMS系统在半导体设备安全管理中的作用

在半导体制造行业&#xff0c;设备的安全管理对于保障生产运行和员工安全至关重要。中央设备状态监控系统CMS&#xff08;central monitoring system&#xff09;是一种关键的解决方案&#xff0c;为企业提供实时监测和报警功能&#xff0c;有效应对设备安全管理的挑战。本文将…

【IDEA2023】解决IDEA中快捷键Alt+Enter不能引入局部变量

1、打开设置 File ➡️ Settings ➡️ Editor ➡️ Intentions 搜索refactorings&#xff0c;将Introduce local variable这个选项勾选上 将Introduce local variable这个选项勾选上 OK&#xff0c;Apply

怎么制作思维导图简单又漂亮?看看这几款常用模板

怎么制作思维导图简单又漂亮&#xff1f;制作思维导图可以帮助我们更好地梳理思路、整理信息。它可以让我们将复杂的信息变得易于理解和记忆&#xff0c;并且可以帮助我们更好地组织各种想法和概念。通过制作思维导图&#xff0c;我们可以更清晰地看到问题的本质&#xff0c;找…

3.Cesium中实体Entity创建(超详细)

前言 在学习 Cesium 的过程中&#xff0c;我发现官方文档冗长且阅读困难&#xff0c;为此我结合官方文档与自己的学习笔记&#xff0c;对其进行归类总结&#xff1b;本文中&#xff0c;我将介绍 Cesium 中创建实体的方法&#xff0c;并对其进行分类&#xff0c;帮助读者快速理解…

【中危】Apache StreamPipes <0.92.0 权限管理不当漏洞

漏洞描述 Apache StreamPipes 是一个开源的数据流处理框架。 Apache StreamPipes 受影响版本中由于 UserResource.java 中的 updateAppearanceMode、registerUser、registerService 函数未对用户身份进行验证&#xff0c;具有登录权限的普通用户可通过 {userId}/appearance/m…

火得不要不要的人工智能,SpringBoot实现人脸识别功能

需求分析 一、人脸注册 step1&#xff1a;人像采集。在注册页面上用html中video组件和js调用笔记本摄像头&#xff0c;并抓取人像图片。没有摄像头的笔记本、台式机的童鞋告辞吧&#xff0c;走好不送。。。 step2&#xff1a;人像上传至项目文件夹。将在页面采集到的人像数据…