使用 ORPO 微调 Llama 3

news2024/5/18 15:15:11

原文地址:https://towardsdatascience.com/fine-tune-llama-3-with-orpo-56cfab2f9ada

更便宜、更快的统一微调技术

2024 年 4 月 19 日

ORPO 是一种新的令人兴奋的微调技术,它将传统的监督微调和偏好校准阶段合并为一个过程。这减少了训练所需的计算资源和时间。此外,经验结果表明,在各种模型大小和基准上,ORPO 都优于其他配准方法。

在本文中,我们将使用 ORPO 和 TRL 库对新的 Llama 3 8B 模型进行微调。

ORPO

指令调整和偏好对齐是使大型语言模型(LLM)适应特定任务的基本技术。传统上,这涉及一个多阶段过程:1/ 对指令进行监督微调 (SFT),使模型适应目标领域;2/偏好调整方法,如人工反馈强化学习 (RLHF) 或直接偏好优化 (DPO),以提高生成首选响应而非拒绝响应的可能性。

12

不过,研究人员也发现了这种方法的局限性。虽然 SFT 能有效地使模型适应所需的领域,但却无意中增加了在生成首选答案的同时生成不想要的答案的概率。这就是为什么有必要进行偏好调整阶段,以拉大首选输出和拒绝输出的可能性之间的差距。

13

由 Hong 和 Lee(2024 年)提出的 ORPO 将指令调整和偏好调整结合到一个单一的、整体的训练过程中,为这一问题提供了一个优雅的解决方案。ORPO 修改了标准语言建模目标,将负对数似然损失与几率比(OR)项相结合。这种赔率损失会对被拒绝的反应进行弱惩罚,同时对偏好的反应进行强奖励,从而使模型能够同时学习目标任务并与人类偏好保持一致。

14

使用 ORPO 微调 Llama 3

Llama 3 是 Meta 开发的最新 LLM 系列。这些模型是在一个包含 15 万亿个词库(相比之下,Llama 2 包含 2T 个词库)的广泛数据集上训练的。目前已发布两种规模的模型:700 亿参数模型和较小的 80 亿参数模型。70B 模型已经表现出令人印象深刻的性能,在 MMLU 基准测试中获得 82 分,在 HumanEval 基准测试中获得 81.7 分。

Llama 3 模型还将上下文长度增加到 8,192 个标记(Llama 2 为 4,096 个标记),并有可能通过 RoPE 扩展到 32k。此外,这些模型还使用了具有 128K 标记词汇的新标记化器,从而将文本编码所需的标记数量减少了 15%。这个词汇量也是参数从 7B 增加到 8B 的原因。

15

ORPO 需要一个偏好数据集,包括提示、选择的答案和拒绝的答案。

按照惯例,我们先安装所需的库:

pip install -U transformers datasets accelerate peft trl bitsandbytes wandb

安装完成后,我们就可以导入必要的库并登录 W&B(可选):

import gc
import os
import torch
import wandb
from datasets import load_dataset
from google.colab import userdata
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format
wb_token = userdata.get('wandb')
wandb.login(key=wb_token)

如果你使用的是最新的 GPU,你还可以使用 Flash Attention 库,以更高效的方式替换默认的急切关注实现。

if torch.cuda.get_device_capability()[0] >= 8:
    !pip install -qqq flash-attn
    attn_implementation = "flash_attention_2"
    torch_dtype = torch.bfloat16
else:
    attn_implementation = "eager"
    torch_dtype = torch.float16

在下文中,我们将利用 bitsandbytes 以 4 位精度加载 Llama 3 8B 模型。然后,我们使用 QLoRA 的 PEFT 设置 LoRA 配置。我还使用方便的 setup_chat_format() 函数来修改模型和标记符,以支持 ChatML。它会自动应用聊天模板,添加特殊标记,并调整模型嵌入层的大小,以匹配新的词汇量大小。

# Model
base_model = "meta-llama/Meta-Llama-3-8B"
new_model = "OrpoLlama-3-8B"
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)
# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)
# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)

既然模型已经准备好进行训练,我们就可以处理数据集了。我们加载 mlabonne/orpo-dpo-mix-40k 并使用 apply_chat_template() 函数将 "选择 "和 "拒绝 "列转换为 ChatML 格式。请注意,我只使用了 1,000 个样本,而不是整个数据集,因为这样运行时间太长。

dataset_name = "mlabonne/orpo-dpo-mix-40k"
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=42).select(range(10))
def format_chat_template(row):
    row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
    row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
    return row
dataset = dataset.map(
    format_chat_template,
    num_proc= os.cpu_count(),
)
dataset = dataset.train_test_split(test_size=0.01)

首先,我们需要设置几个超参数:

  • 学习率: 与传统的 SFT 甚至 DPO 相比,ORPO 使用非常低的学习率。这个 8e-6 的值来自原始论文,大致相当于 1e-5 的 SFT 学习率和 5e-6 的 DPO 学习率。我建议将其提高到 1e-6 左右,以进行真正的微调。
  • beta: 它是论文中的 $\lambda$ 参数,默认值为 0.1。原始论文的附录展示了如何通过烧蚀研究来选择它。
  • 其他参数,如 max_length 和批量大小,都设置为使用尽可能多的可用 VRAM(在此配置中约为 20 GB)。理想情况下,我们会对模型进行 3-5 个历元的训练,但这里我们会坚持使用 1 个历元。

最后,我们可以使用 ORPOTrainer 对模型进行训练。

orpo_args = ORPOConfig(
    learning_rate=8e-6,
    beta=0.1,
    lr_scheduler_type="linear",
    max_length=1024,
    max_prompt_length=512,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    optim="paged_adamw_8bit",
    num_train_epochs=1,
    evaluation_strategy="steps",
    eval_steps=0.2,
    logging_steps=1,
    warmup_steps=10,
    report_to="wandb",
    output_dir="./results/",
)
trainer = ORPOTrainer(
    model=model,
    args=orpo_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    tokenizer=tokenizer,
)
trainer.train()
trainer.save_model(new_model)

在 L4 GPU 上对这 1,000 个样本进行模型训练耗时约 2 小时。我们来看看 W&B 图:

16

虽然损失减少了,但选择答案和拒绝答案之间的差异并不明显:平均余量和准确率分别仅略高于零和 0.5。

在原始论文中,作者在 Anthropic/hh-rlhf 数据集(161k 个样本)上训练了 10 个历时的模型,这比我们快速运行的时间要长得多。他们还使用 Llama 3 进行了实验,并慷慨地分享了他们的日志。

在本文的最后,让我们将 QLoRA 适配器与基础模型合并,并将其推送到Hugging Face Hub。

# Flush memory
del trainer, model
gc.collect()
torch.cuda.empty_cache()
# Reload tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
model, tokenizer = setup_chat_format(model, tokenizer)
# Merge adapter with base model
model = PeftModel.from_pretrained(model, new_model)
model = model.merge_and_unload()
model.push_to_hub(new_model, use_temp_dir=False)
tokenizer.push_to_hub(new_model, use_temp_dir=False)

现在,我们完成了对 Llama 3 的快速微调:mlabonne/OrpoLlama-3-8B。你可以使用这个 "Hugging Face Space"来体验一下。正如 W&B 曲线所示,虽然模型训练不足,但我还是使用 LLM AutoEval 在 Nous 的基准套件上进行了一些评估。

17

我们的 ORPO 微调实际上非常不错,在每个基准测试中都提高了基本型号的性能。这是令人鼓舞的,很可能意味着对整个 40k 样本进行微调会产生很好的结果。

对于开源社区来说,这是一个激动人心的时刻,越来越多的高质量开放重量模型被发布。闭源模型和开放重量模型之间的差距正在慢慢缩小,而微调是为你的使用案例获得最佳性能的重要工具。

18

结论

在本文中,我们介绍了 ORPO 算法,并解释了它如何将 SFT 和偏好校准阶段统一为一个过程。然后,我们使用 TRL 在自定义偏好数据集上对 Llama 3 8B 模型进行了微调。最终的模型显示了令人鼓舞的结果,并凸显了 ORPO 作为一种新的微调范例的潜力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1641696.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

8.MyBatis 操作数据库(进阶)

文章目录 1.动态SQL插入1.1使用注解方式插入数据1.2使用xml方式插入数据1.3何时用注解何时用xml?1.4使用SQL查询中有多个and时,如何自动去除多余and1.4.1方法一:删除and之后的代码如图所示,再次运行1.4.2方法二:加上tr…

C语言——文件相关操作

2.什么是文件 3.文件的打开和关闭 4.文件的顺序读写 5.文件的随机读写 6.文本文件和二进制文件 7.文件读取结束的判定 8.文件缓冲区 一、文件相关介绍 1、为什么使用文件 文件用于永久存储数据。通过使用文件,我们可以在程序关闭后保存数据,以便将来…

Springboot图片上传【本地+oss】

文章目录 1 前端组件页面2 本地上传3 上传到阿里云oss3.1申请开通账号&#xff0c;做好先导准备3.2 开始使用 1 前端组件页面 使用的VueElement组件 在线cdn引入&#xff1a; <script src"https://cdn.bootcdn.net/ajax/libs/vue/2.7.16/vue.js"></script&…

Simulink|【免费】虚拟同步发电机(VSG)惯量阻尼自适应控制仿真模型

目录 主要内容 仿真模型要点 2.1 整体仿真模型 2.2 电压电流双闭环模块 2.3 SVPWM调制策略 2.4 无功电压模块 2.5 自适应控制策略及算法 部分结果 下载链接 主要内容 该模型为simulink仿真模型&#xff0c;主要实现的内容如下&#xff1a; 随着风力发电、…

免费APP分发平台 - 一个指南和解析

数字化时代的APP分发平台 随着数字化进程的加速免费APP分发平台 - 一个指南和解析&#xff0c;移动应用&#xff08;APP&#xff09;市场正迅速扩大。在这个充满竞争的市场中免费APP分发平台 - 一个指南和解析&#xff0c;一个优秀的APP分发平台能够帮助开发者和商家更有效地触…

用keras识别狗狗

一、需求场景 从照片从识别出狗狗 from keras.applications.resnet50 import ResNet50 from keras.preprocessing import image from keras.applications.resnet50 import preprocess_input, decode_predictions import numpy as np# 加载预训练的ResNet50模型 model ResNet5…

网络知识点之—QoS

QoS&#xff08;Quality of Service&#xff0c;服务质量&#xff09;指一个网络能够利用各种基础技术&#xff0c;为指定的网络通信提供更好的服务能力&#xff0c;是网络的一种安全机制&#xff0c; 是用来解决网络延迟和阻塞等问题的一种技术。QoS的保证对于容量有限的网络来…

【matlab基础知识】(三)二维曲线绘制plot

x[-pi:0.0001:pi]; 选择较小步距 ysin(tan(x))-tan(sin(x));plot(x,y) 条件和函数值做一个点乘 x[-2:0.02:2];y1.1*sign(x).*(abs(x)>1.1)x.*(abs(x)<1.1);plot(x,y) 颜色&#xff0c;线形&#xff0c;曲线上的标志 由于0.01cosx波动太小&#xff0c;所以plotyy绘制多…

C语言 | Leetcode C语言题解之第64题最小路径和

题目&#xff1a; 题解&#xff1a; int minPathSum(int** grid, int gridSize, int* gridColSize) {int rows gridSize, columns gridColSize[0];if (rows 0 || columns 0) {return 0;}int dp[rows][columns];dp[0][0] grid[0][0];for (int i 1; i < rows; i) {dp[i…

【吃透Java手写】- Spring(上)-启动-扫描-依赖注入-初始化-后置处理器

【吃透Java手写】Spring&#xff08;上&#xff09;启动-扫描-依赖注入-初始化-后置处理器 1 准备工作1.1 创建自己的Spring容器类1.2 创建自己的配置类 ComponentScan1.3 ComponentScan1.3.1 Retention1.3.2 Target 1.4 用户类UserService Component1.5 Component1.6 测试类 2…

STM32——WWDG(窗口看门狗)

技术笔记&#xff01; 1.WWDG&#xff08;窗口看门狗&#xff09;简介 本质&#xff1a;能产生系统复位信号和提前唤醒中断的计数器。 特性&#xff1a; 递减的计数器&#xff1b; 当递减计数器值从 0x40减到0x3F时复位&#xff08;即T6位跳变到0&#xff09;&#xff1b; …

GPT-ArcGIS数据处理、空间分析、可视化及多案例综合应用

在数字化和智能化的浪潮中&#xff0c;GIS&#xff08;地理信息系统&#xff09;和GPT&#xff08;生成式预训练模型&#xff09;的结合正日益成为推动科研、城市规划、环境监测等领域发展的关键技术。GIS以其强大的空间数据处理、先进的空间分析工具、灵活的地图制作与可视化能…

OpenCV 实现重新映射(53)

返回:OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇&#xff1a;OpenCV 实现霍夫圆变换(52) 下一篇 :OpenCV实现仿射变换(54) 目标 在本教程中&#xff0c;您将学习如何&#xff1a; 一个。使用 OpenCV 函数 cv&#xff1a;&#xff1a;remap 实现简…

mysql-sql-练习题-4-标记(排名 条件判断)

标记 标记找规律连续登录2-7天用户建表排名找规律 最大连胜次数建表多次排名 找规律输出更多数据 标记计数 百分比 标记找规律 连续登录2-7天用户 建表 create table continuous_login(user_id1 integer comment 用户id,date_login date comment 登陆日期 ) comment 用户登录…

一加12/11/10/Ace2/Ace3手机上锁回锁BL无限重启黑屏9008模式救砖

一加12/11/10/Ace2/Ace3手机官方都支持解锁BL&#xff0c;搞机的用户也比较多&#xff0c;相对于其他品牌来说&#xff0c;并没有做出限制&#xff0c;这也可能是搞机党最后的救命稻草。而厌倦了root搞机的用户&#xff0c;就习惯性回锁BL&#xff0c;希望彻底变回官方原来的样…

约瑟夫问题新解法

前言 又碰到了约瑟夫问题&#xff0c;这样的题目本来用环形链表模拟的话就能做出来。然而&#xff0c;最近新学习了一种做法&#xff0c;实在是有点震惊到我了。无论是思路上&#xff0c;还是代码量上&#xff0c;都是那么的精彩。就想也震惊一下其他人。谁能想到原来模拟出来四…

Go-变量

可以理解为一个昵称 以后这个昵称就代指这些信息 var sg string "czy" 声明赋值 package mainimport "fmt"func main() {var sg string "陈政洋"fmt.Println(sg)var age int 73fmt.Println(age)var flag bool truefmt.Println(flag) } …

【JVM】内存调优——内存泄漏、内存溢出

内存调优 什么是内存泄漏、内存泄漏&#xff1f; 内存泄漏&#xff1a;在Java中如果不再使用一个对象&#xff0c;但是该对象依然在GC ROOT的引用链上&#xff0c;这个对象就不会被垃圾回收器回收。内存溢出&#xff1a;内存的使用量超过了Java虚拟机可以分配的上限&#xff…

ARP欺骗使局域网内设备断网

一、实验准备 kali系统&#xff1a;可使用虚拟机软件模拟 kali虚拟机镜像链接&#xff1a;https://www.kali.org/get-kali/#kali-virtual-machines 注意虚拟机网络适配器采用桥接模式 局域网内存在指定断网的设备 二、实验步骤 打开kali系统命令行&#xff1a;ctrlaltt可快…

栈的表达式求值中的应用——逆波兰表达式求值+中缀表达式转后缀表达式

文章目录 1. 逆波兰表达式&#xff08;后缀表达式&#xff09;求值思路讲解AC代码 2. 中缀表达式转后缀表达式分析方法总结 3. 中缀表达式求值 1. 逆波兰表达式&#xff08;后缀表达式&#xff09;求值 链接: link 这道题目叫做逆波兰表达式求值&#xff0c;那什么是逆波兰表…