From ColossalChat to DeepSpeedChat: Applying and Optimizing RLHF
Author: 紫气东来
Original post: https://zhuanlan.zhihu.com/p/621391363
1. A Deep Dive into ColossalChat
The previous post, NLP (9): LLaMA, Alpaca, and the ColossalChat model family (https://zhuanlan.zhihu.com/p/618695885), briefly introduced ColossalChat's training procedure in its final section. Below, the three training stages are unpacked one by one.
1.1 Stage 1: Supervised instruction fine-tuning
This stage is trained on "instinwild_en.json" from the InstructionWild dataset, 52K samples in total. The data format is shown below:
InstructionWild dataset: https://github.com/XueFuzhao/InstructionWild/tree/main/data
{
    "instruction": "Provide a list of the top 10 most popular mobile games in Asia",
    "input": "",
    "output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
    "id": 0
}
The SFT training logic is shown in the figure below; it is the standard instruction-tuning setup:
(Figure: SFT training pipeline)
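To make the objective concrete, below is a minimal, hypothetical sketch of instruction tuning as a causal-LM cross-entropy loss with the prompt tokens masked out of the labels. The "gpt2" checkpoint and the prompt format are placeholders, not ColossalChat's actual trainer code.

```python
# Hedged sketch of supervised instruction tuning (illustrative only).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")    # small stand-in for gpt2-xl
model = AutoModelForCausalLM.from_pretrained("gpt2")

def sft_loss(prompt: str, response: str) -> torch.Tensor:
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    full_ids = tokenizer(prompt + response, return_tensors="pt").input_ids
    labels = full_ids.clone()
    labels[:, : prompt_ids.shape[1]] = -100           # ignore prompt tokens in the loss
    return model(input_ids=full_ids, labels=labels).loss

loss = sft_loss("Instruction: Name three fruits.\nResponse: ", "Apple, banana, cherry.")
loss.backward()                                       # gradients for a single SFT step
```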
To save GPU resources, a GPT2-XL model is used for training, with the following command:
torchrun --standalone --nproc_per_node=8 train_sft.py \
    --pretrain "gpt2-xl" \
    --model 'gpt2' \
    --strategy colossalai_zero2 \
    --log_interval 10 \
    --save_path './trained_sft_gpt2-xl' \
    --dataset 'dataset/instinwild_en.json' \
    --batch_size 2 \
    --accimulation_steps 8 \
    --lr 2e-5 \
    --max_datasets_size 51200 \
    --max_epochs 2
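The combination of --batch_size 2 and --accimulation_steps 8 gives an effective per-GPU batch of 16. Purely as an illustration of that mechanism (the ColossalChat trainer handles it internally; the linear model below is a dummy), a gradient-accumulation loop looks roughly like this:

```python
# Illustrative gradient accumulation: one optimizer step per 8 micro-batches.
import torch

model = torch.nn.Linear(4, 1)                            # dummy stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
accumulation_steps = 8

for step in range(32):                                   # pretend dataloader
    x, y = torch.randn(2, 4), torch.randn(2, 1)          # micro-batch of size 2
    loss = torch.nn.functional.mse_loss(model(x), y) / accumulation_steps
    loss.backward()                                      # gradients accumulate across micro-batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```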
The learning rate and loss recorded during training are shown in the figure below:
(Figure: learning rate and loss curves during SFT training)
1.2 Stage 2: Training the reward model
The RM is trained on the Anthropic/hh-rlhf dataset, with 161K training examples in total. A data sample is shown below:
Anthropic/hh-rlhf dataset: https://huggingface.co/datasets/Anthropic/hh-rlhf
| chosen (string) | rejected (string) |
|---|---|
| " Human: What kind of noises did dinosaurs make? Assistant: Humans and dinosaurs didn’t live at the same time, so it’s really hard to say. The best place to find out what noises dinosaurs made would be Human: yes they did Assistant: to guess, and that would probably require lots of reading and a certain amount of imagination, so we’re not really prepared to do that. Human: you cant read Assistant: You can read?" | " Human: What kind of noises did dinosaurs make? Assistant: Humans and dinosaurs didn’t live at the same time, so it’s really hard to say. The best place to find out what noises dinosaurs made would be Human: yes they did Assistant: to guess, and that would probably require lots of reading and a certain amount of imagination, so we’re not really prepared to do that. Human: you cant read Assistant: there’s a lot of stuff humans don’t know" |
The RM training logic is shown in the figure below:
(Figure: reward model training pipeline)
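The 'log_sig' loss selected in the command below is the standard pairwise ranking objective: maximize log sigmoid(r_chosen - r_rejected). A minimal sketch, with dummy scalar rewards standing in for the outputs of a GPT-2 backbone with a value head:

```python
# Hedged sketch of the pairwise reward-model loss (illustrative only).
import torch
import torch.nn.functional as F

def log_sigmoid_loss(chosen_reward: torch.Tensor, rejected_reward: torch.Tensor) -> torch.Tensor:
    # loss = -log(sigmoid(r_chosen - r_rejected)), averaged over the batch
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()

chosen = torch.tensor([1.2, 0.3, -0.5, 2.0])    # dummy rewards for 4 preference pairs
rejected = torch.tensor([0.4, 0.8, -1.0, 1.5])
print(log_sigmoid_loss(chosen, rejected))       # scalar training loss
```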
The training command is as follows:
torchrun --standalone --nproc_per_node=8 train_reward_model.py \
    --pretrain 'gpt2-xl' \
    --model 'gpt2' \
    --strategy colossalai_zero2 \
    --loss_fn 'log_sig' \
    --save_path 'trained_rm_gpt2-xl.pt' \
    --dataset 'Anthropic/hh-rlhf'
During training, besides the loss, two validation-set metrics are also tracked: dist (the mean of chosen_reward - reject_reward) and acc. The results are shown below:
(Figure: dist and acc on the validation set during RM training)
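For reference, here is a small sketch of how these two validation metrics can be computed from per-pair rewards (illustrative, not the project's exact evaluation code):

```python
# dist = mean(chosen_reward - reject_reward); acc = fraction of pairs ranked correctly.
import torch

def rm_metrics(chosen_reward: torch.Tensor, reject_reward: torch.Tensor):
    dist = (chosen_reward - reject_reward).mean()
    acc = (chosen_reward > reject_reward).float().mean()
    return dist.item(), acc.item()

chosen = torch.tensor([0.8, 1.3, -0.2, 0.5])    # dummy validation rewards
reject = torch.tensor([0.1, 1.5, -0.9, 0.0])
print(rm_metrics(chosen, reject))               # (dist, acc)
```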
1.3 Stage 3: Training the model with reinforcement learning from human feedback
Stage 3 fine-tunes an RL model against the reward function trained in Stage 2. During fine-tuning, the PPO algorithm constrains how far the RL model's parameters may move: with the Stage-1 SFT model's policy as the reference, PPO keeps the policy from drifting too far from the SFT baseline. The overall procedure is shown in the figure below and can be roughly divided into two parts:

- Part 1: Make Experience. The SFT, Actor, RM, and Critic models are used to generate experience and store it in a buffer; in this part every model only runs forward inference.
- Part 2: The collected experience is used to compute the value loss and the policy loss, and the parameters are updated.

(Figure: overall pipeline of ColossalChat Stage-3 RLHF training)
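As a rough, hypothetical illustration of the two parts (not ColossalChat's exact implementation): in part 1 the sequence-level RM score is typically shaped with a KL penalty toward the SFT reference policy, and in part 2 the Actor and Critic are updated with the clipped PPO policy loss plus a value regression loss. All tensors below are dummies.

```python
# Hedged sketch of the PPO-style losses used in RLHF (illustrative only).
import torch

def shaped_reward(rm_score, logprobs, ref_logprobs, kl_coef=0.1):
    # Penalize divergence from the frozen SFT reference policy, per token.
    return rm_score - kl_coef * (logprobs - ref_logprobs)

def policy_loss(logprobs, old_logprobs, advantages, clip_eps=0.2):
    ratio = torch.exp(logprobs - old_logprobs)                 # pi_new / pi_old
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)
    return -torch.min(ratio * advantages, clipped * advantages).mean()

def value_loss(values, returns):
    return 0.5 * torch.mean((values - returns) ** 2)

# Dummy tensors for a batch of 8 generated tokens.
lp = torch.randn(8, requires_grad=True)          # Actor log-probs
old_lp, ref_lp = torch.randn(8), torch.randn(8)  # behavior policy / SFT reference
vals = torch.randn(8, requires_grad=True)        # Critic value estimates
rewards = shaped_reward(torch.tensor(1.0), lp.detach(), ref_lp)  # would feed advantage estimation
adv, rets = torch.randn(8), torch.randn(8)       # pretend GAE advantages / returns
loss = policy_loss(lp, old_lp, adv) + value_loss(vals, rets)
loss.backward()
```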
Using the SFT and RM models obtained in the first two stages, the full Stage-3 training command is as follows:
torchrun --standalone --nproc_per_node=8 train_prompts.py \
    --pretrain "trained_sft_gpt2-xl" \
    --model 'gpt2' \
    --strategy colossalai_zero2 \
    --prompt_path "dataset/seed_prompts_en.jsonl" \
    --pretrain_dataset 'dataset/instinwild_en.json' \
    --rm_model 'gpt2' \
    --rm_pretrain "gpt2-xl" \
    --rm_path 'trained_rm_gpt2-xl.pt' \
    --train_batch_size 1 \
    --experience_batch_size 1 \
    --num_episodes 20 \
    --max_epochs 20
2. Hands-on with DeepSpeedChat
In terms of the training procedure, DeepSpeedChat does not differ noticeably from ColossalChat: both follow the same three stages. The difference lies mainly in a number of engineering optimizations.

For comparability with ColossalChat, models of a similar scale are used. All three stages can be run with a single command:
python3 train.py --actor-model facebook/opt-1.3b --reward-model facebook/opt-350m --deployment-type single_node
Under the hood, the three stages are still executed as three separate steps.

Step 1: SFT training. The training command is:
deepspeed main.py \
    --data_path Dahoas/rm-static Dahoas/full-hh-rlhf Dahoas/synthetic-instruct-gptj-pairwise yitingxie/rlhf-reward-datasets \
    --data_split 2,4,4 \
    --model_name_or_path facebook/opt-1.3b \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --max_seq_len 512 \
    --learning_rate 9.65e-6 \
    --weight_decay 0. \
    --num_train_epochs 16 \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --num_warmup_steps 0 \
    --seed 1234 \
    --zero_stage 2 \
    --deepspeed \
    --output_dir "./output"
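As I understand DeepSpeed-Chat's data handling, --data_split 2,4,4 partitions each dataset across the three steps in a 2:4:4 ratio, so SFT, reward modeling, and RLHF train on disjoint portions. A small illustrative helper for that kind of split (not the library's own code):

```python
# Hypothetical helper showing a 2:4:4 partition of a dataset across three phases.
def split_sizes(n_samples: int, ratios=(2, 4, 4)):
    total = sum(ratios)
    sizes = [n_samples * r // total for r in ratios]
    sizes[-1] = n_samples - sum(sizes[:-1])   # give any remainder to the last phase
    return sizes

print(split_sizes(76_256))   # e.g. sizes of the SFT / RM / RLHF partitions
```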
Part of the Stage-1 training log is shown below:
***** Running training *****
***** Evaluating perplexity, Epoch 0/16 *****
ppl: 4937.2431640625
Beginning of Epoch 1/16, Total Micro Batches 920
[2023-05-01 17:08:55,881] [INFO] [logging.py:96:log_dist] [Rank 0] step=10, skipped=6, lr=[9.649998241787337e-06, 9.649998241787337e-06], mom=[(0.9, 0.95), (0.9, 0.95)]
[2023-05-01 17:08:56,078] [INFO] [timer.py:199:stop] epoch=0/micro_step=10/global_step=10, RunningAvgSamplesPerSec=69.28514057779762, CurrSamplesPerSec=62.312571874000376, MemAllocated=4.98GB, MaxMemAllocated=23.97GB
[2023-05-01 17:09:06,144] [INFO] [logging.py:96:log_dist] [Rank 0] step=20, skipped=6, lr=[9.649978461909591e-06, 9.649978461909591e-06], mom=[(0.9, 0.95), (0.9, 0.95)]
[2023-05-01 17:09:06,339] [INFO] [timer.py:199:stop] epoch=0/micro_step=20/global_step=20, RunningAvgSamplesPerSec=65.28289507772782, CurrSamplesPerSec=62.544054882381914, MemAllocated=4.98GB, MaxMemAllocated=23.97GB
...
***** Evaluating perplexity, Epoch 1/16 *****
ppl: 2.0012214183807373
Beginning of Epoch 2/16, Total Micro Batches 920
[2023-05-01 17:24:53,358] [INFO] [logging.py:96:log_dist] [Rank 0] step=930, skipped=13, lr=[9.55789070120902e-06, 9.55789070120902e-06], mom=[(0.9, 0.95), (0.9, 0.95)]
[2023-05-01 17:24:53,558] [INFO] [timer.py:199:stop] epoch=1/micro_step=10/global_step=930, RunningAvgSamplesPerSec=62.532406953883935, CurrSamplesPerSec=62.11747630869141, MemAllocated=4.98GB, MaxMemAllocated=23.97GB
[2023-05-01 17:25:03,622] [INFO] [logging.py:96:log_dist] [Rank 0] step=940, skipped=13, lr=[9.555877413047903e-06, 9.555877413047903e-06], mom=[(0.9, 0.95), (0.9, 0.95)]
[2023-05-01 17:25:03,819] [INFO] [timer.py:199:stop] epoch=1/micro_step=20/global_step=940, RunningAvgSamplesPerSec=62.53102982115688, CurrSamplesPerSec=62.38918457600946, MemAllocated=4.98GB, MaxMemAllocated=23.97GB
...
***** Evaluating perplexity, Epoch 15/16 *****
ppl: 1.7830698490142822
Beginning of Epoch 16/16, Total Micro Batches 920
[2023-05-01 21:07:37,315] [INFO] [logging.py:96:log_dist] [Rank 0] step=13810, skipped=265, lr=[1.5092112560532933e-07, 1.5092112560532933e-07], mom=[(0.9, 0.95), (0.9, 0.95)]
[2023-05-01 21:07:37,514] [INFO] [timer.py:199:stop] epoch=15/micro_step=10/global_step=13810, RunningAvgSamplesPerSec=62.66213404288616, CurrSamplesPerSec=62.30927408141939, MemAllocated=4.98GB, MaxMemAllocated=23.97GB
[2023-05-01 21:07:47,568] [INFO] [logging.py:96:log_dist] [Rank 0] step=13820, skipped=265, lr=[1.4837637890662103e-07, 1.4837637890662103e-07], mom=[(0.9, 0.95), (0.9, 0.95)]
[2023-05-01 21:07:47,765] [INFO] [timer.py:199:stop] epoch=15/micro_step=20/global_step=13820, RunningAvgSamplesPerSec=62.66198725001066, CurrSamplesPerSec=62.56049691270701, MemAllocated=4.98GB, MaxMemAllocated=23.97GB
...
***** Evaluating perplexity, Epoch 16/16 *****
ppl: 1.780816674232483
saving the final model ...
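The ppl value reported after each epoch is perplexity, i.e. the exponential of the mean evaluation cross-entropy loss. A tiny sketch with dummy losses:

```python
# Perplexity = exp(mean cross-entropy); `losses` here are made-up eval losses.
import torch

losses = torch.tensor([0.61, 0.58, 0.55])
perplexity = torch.exp(losses.mean())
print(perplexity)   # tensor(1.7860), in the same ballpark as the final ppl above
```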
Step 2: Reward Model training. The training command is:

deepspeed main.py \
    --data_path Dahoas/rm-static Dahoas/full-hh-rlhf Dahoas/synthetic-instruct-gptj-pairwise yitingxie/rlhf-reward-datasets \
    --data_split 2,4,4 \
    --model_name_or_path facebook/opt-350m \
    --num_padding_at_beginning 1 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --max_seq_len 512 \
    --learning_rate 5e-5 \
    --weight_decay 0.1 \
    --num_train_epochs 1 \
    --disable_dropout \
    --gradient_accumulation_steps 1 \
    --lr_scheduler_type cosine \
    --num_warmup_steps 0 \
    --seed 1234 \
    --zero_stage 0 \
    --deepspeed \
    --output_dir "./output"
Part of the Stage-2 training log is shown below:
***** Running training *****
***** Evaluating reward, Epoch 0/1 *****
chosen_last_scores (higher is better) : 2.8095741271972656, acc (higher is better) : 0.4898989498615265
Beginning of Epoch 1/1, Total Micro Batches 3680
...
[2023-05-02 17:21:25,830] [INFO] [logging.py:96:log_dist] [Rank 0] step=10, skipped=7, lr=[4.999991801084829e-05, 4.999991801084829e-05], mom=[(0.9, 0.95), (0.9, 0.95)]
[2023-05-02 17:21:25,849] [INFO] [timer.py:199:stop] epoch=0/micro_step=10/global_step=10, RunningAvgSamplesPerSec=100.85116912167754, CurrSamplesPerSec=92.62450760911271, MemAllocated=4.32GB, MaxMemAllocated=12.79GB
[2023-05-02 17:21:29,272] [INFO] [logging.py:96:log_dist] [Rank 0] step=20, skipped=7, lr=[4.999846044088921e-05, 4.999846044088921e-05], mom=[(0.9, 0.95), (0.9, 0.95)]
[2023-05-02 17:21:29,291] [INFO] [timer.py:199:stop] epoch=0/micro_step=20/global_step=20, RunningAvgSamplesPerSec=96.39672859406251, CurrSamplesPerSec=93.42591718628839, MemAllocated=4.32GB, MaxMemAllocated=12.79GB
[2023-05-02 17:21:32,713] [INFO] [logging.py:96:log_dist] [Rank 0] step=30, skipped=7, lr=[4.9995181012051625e-05, 4.9995181012051625e-05], mom=[(0.9, 0.95), (0.9, 0.95)]
...
[2023-05-02 17:28:40,660] [INFO] [logging.py:96:log_dist] [Rank 0] step=1270, skipped=17, lr=[3.701016326089881e-05, 3.701016326089881e-05], mom=[(0.9, 0.95), (0.9, 0.95)]
[2023-05-02 17:28:40,679] [INFO] [timer.py:199:stop] epoch=0/micro_step=1270/global_step=1270, RunningAvgSamplesPerSec=92.90538668585853, CurrSamplesPerSec=92.88912634851383, MemAllocated=4.32GB, MaxMemAllocated=12.79GB
[2023-05-02 17:28:44,120] [INFO] [logging.py:96:log_dist] [Rank 0] step=1280, skipped=17, lr=[3.682254575425273e-05, 3.682254575425273e-05], mom=[(0.9, 0.95), (0.9, 0.95)]
[2023-05-02 17:28:44,139] [INFO] [timer.py:199:stop] epoch=0/micro_step=1280/global_step=1280, RunningAvgSamplesPerSec=92.90310674484337, CurrSamplesPerSec=92.82032648798923, MemAllocated=4.32GB, MaxMemAllocated=12.79GB
...
[2023-05-02 17:42:28,021] [INFO] [logging.py:96:log_dist] [Rank 0] step=3660, skipped=29, lr=[2.1869706348343066e-08, 2.1869706348343066e-08], mom=[(0.9, 0.95), (0.9, 0.95)]
[2023-05-02 17:42:28,040] [INFO] [timer.py:199:stop] epoch=0/micro_step=3660/global_step=3660, RunningAvgSamplesPerSec=92.68522051842969, CurrSamplesPerSec=93.28319595223864, MemAllocated=4.32GB, MaxMemAllocated=12.79GB
[2023-05-02 17:42:31,478] [INFO] [logging.py:96:log_dist] [Rank 0] step=3670, skipped=29, lr=[1.385489430420217e-08, 1.385489430420217e-08], mom=[(0.9, 0.95), (0.9, 0.95)]
[2023-05-02 17:42:31,497] [INFO] [timer.py:199:stop] epoch=0/micro_step=3670/global_step=3670, RunningAvgSamplesPerSec=92.68524203482882, CurrSamplesPerSec=93.03399659243877, MemAllocated=4.32GB, MaxMemAllocated=12.79GB
...
Epoch 1/1 with loss inf
***** Evaluating reward, Epoch 1/1 *****
chosen_last_scores (higher is better) : -0.4733814597129822, acc (higher is better) : 0.6717171669006348
saving model ...
Step 3: RLHF training. The training command is:

Actor_Lr=9.65e-6
Critic_Lr=5e-6
deepspeed --master_port 12346 main.py \
    --data_path Dahoas/rm-static \
    --data_split 2,4,4 \
    --actor_model_name_or_path 'output/actor-models/1.3b' \
    --critic_model_name_or_path 'output/reward-models/350m' \
    --num_padding_at_beginning 1 \
    --per_device_train_batch_size 4 \
    --per_device_mini_train_batch_size 4 \
    --generation_batch_numbers 1 \
    --ppo_epochs 1 \
    --max_answer_seq_len 256 \
    --max_prompt_seq_len 256 \
    --actor_learning_rate ${Actor_Lr} \
    --critic_learning_rate ${Critic_Lr} \
    --num_train_epochs 1 \
    --lr_scheduler_type cosine \
    --gradient_accumulation_steps 1 \
    --disable_actor_dropout \
    --num_warmup_steps 100 \
    --deepspeed --seed 1234 \
    --enable_hybrid_engine \
    --actor_zero_stage 2 \
    --critic_zero_stage 2 \
    --enable_ema \
    --output_dir 'output'
Part of the Stage-3 training log is shown below:
...
************************[start] Initializing Actor Model [start] *************************
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
...
*******************[end] Initialized Ref Model [end] (duration: 23.45s)*******************
*************************[start] Initializing EMA Model [start] **************************
Using /root/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
No modifications detected for re-loaded extension module utils, skipping build step...
Loading extension module utils...
Time to load utils op: 0.0008325576782226562 seconds
...
*******************[end] Initialized EMA Model [end] (duration: 20.59s)*******************
************************[start] Initializing Critic Model [start] ************************
Using /root/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
No modifications detected for re-loaded extension module utils, skipping build step...
Loading extension module utils...
Time to load utils op: 0.0007164478302001953 seconds
...
*****************[end] Initialized Reward Model [end] (duration: 14.60s)******************
***** Running training *****
Beginning of Epoch 1/1, Total Generation Batches 954
epoch: 0|step: 0|ppo_ep: 1|act_loss: 0.039031982421875|cri_loss: 0.0031604766845703125|unsuper_loss: 0.0
average reward score: -2.208984375
-------------------------------------------------------------------------------------
|E2E latency=4.19s |Gather latency=0.00s (0.00%) |Generate time=2.82s (67.47%) |Training time=1.13s (27.03%) |Others=0.23 (5.50%)|CurSamplesPerSec=7.65 |AvgSamplesPerSec=7.65
[2023-05-02 17:46:36,487] [INFO] [loss_scaler.py:181:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 65536, reducing to 32768
[2023-05-02 17:46:36,695] [INFO] [loss_scaler.py:181:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 65536, reducing to 32768
epoch: 0|step: 1|ppo_ep: 1|act_loss: 0.041168212890625|cri_loss: 0.036041259765625|unsuper_loss: 0.0
average reward score: -2.3125
-------------------------------------------------------------------------------------
...
-------------------------------------------------------------------------------------
|E2E latency=2.90s |Gather latency=0.00s (0.00%) |Generate time=1.56s (53.92%) |Training time=1.04s (36.09%) |Others=0.29 (10.00%)|CurSamplesPerSec=11.05 |AvgSamplesPerSec=11.06
epoch: 0|step: 36|ppo_ep: 1|act_loss: 0.2275390625|cri_loss: 0.0435791015625|unsuper_loss: 0.0
average reward score: -1.783203125
-------------------------------------------------------------------------------------
|E2E latency=2.89s |Gather latency=0.00s (0.00%) |Generate time=1.56s (53.98%) |Training time=1.04s (35.96%) |Others=0.29 (10.06%)|CurSamplesPerSec=11.06 |AvgSamplesPerSec=11.06
epoch: 0|step: 37|ppo_ep: 1|act_loss: -0.06414794921875|cri_loss: 0.0183868408203125|unsuper_loss: 0.0
average reward score: -2.1171875
-------------------------------------------------------------------------------------
|E2E latency=2.89s |Gather latency=0.00s (0.00%) |Generate time=1.56s (54.05%) |Training time=1.04s (35.97%) |Others=0.29 (9.99%)|CurSamplesPerSec=11.08 |AvgSamplesPerSec=11.06
epoch: 0|step: 38|ppo_ep: 1|act_loss: -0.203857421875|cri_loss: 0.043121337890625|unsuper_loss: 0.0
average reward score: -1.763671875
...
-------------------------------------------------------------------------------------
|E2E latency=2.89s |Gather latency=0.00s (0.00%) |Generate time=1.56s (54.05%) |Training time=1.04s (35.92%) |Others=0.29 (10.04%)|CurSamplesPerSec=11.07 |AvgSamplesPerSec=11.07
epoch: 0|step: 951|ppo_ep: 1|act_loss: 0.041351318359375|cri_loss: 0.07244873046875|unsuper_loss: 0.0
average reward score: -1.9130859375
-------------------------------------------------------------------------------------
|E2E latency=2.89s |Gather latency=0.00s (0.00%) |Generate time=1.56s (54.11%) |Training time=1.03s (35.84%) |Others=0.29 (10.05%)|CurSamplesPerSec=11.09 |AvgSamplesPerSec=11.07
epoch: 0|step: 952|ppo_ep: 1|act_loss: -0.287109375|cri_loss: 0.333251953125|unsuper_loss: 0.0
average reward score: -1.8779296875
-------------------------------------------------------------------------------------
|E2E latency=2.89s |Gather latency=0.00s (0.00%) |Generate time=1.56s (54.06%) |Training time=1.04s (35.88%) |Others=0.29 (10.06%)|CurSamplesPerSec=11.08 |AvgSamplesPerSec=11.07
[2023-05-02 18:32:26,992] [INFO] [loss_scaler.py:181:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 8192, reducing to 4096
epoch: 0|step: 953|ppo_ep: 1|act_loss: -0.48193359375|cri_loss: 1.1884765625|unsuper_loss: 0.0
average reward score: -2.115234375
-------------------------------------------------------------------------------------
saving model ...
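The OVERFLOW! messages come from fp16 dynamic loss scaling: when gradients overflow, the optimizer step is skipped and the loss scale is halved, then cautiously grown back after a run of stable steps. A simplified sketch of that idea (illustrative, not DeepSpeed's actual loss_scaler implementation):

```python
# Hedged sketch of dynamic loss scaling for fp16 training.
def update_loss_scale(scale: float, overflow: bool,
                      steps_since_growth: int = 0, growth_interval: int = 2000):
    if overflow:
        return scale / 2.0, 0            # skip the step and halve the scale
    steps_since_growth += 1
    if steps_since_growth >= growth_interval:
        return scale * 2.0, 0            # grow the scale back after stable steps
    return scale, steps_since_growth

scale, steps = 65536.0, 0
scale, steps = update_loss_scale(scale, overflow=True)
print(scale)   # 32768.0, matching the "reducing to 32768" message above
```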
References
[1] https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat
[2] https://blog.csdn.net/v_JULY_v/article/details/129996493
[3] https://jonathan-hui.medium.com/rl-proximal-policy-optimization-ppo-explained-77f014ec3f12
[4] https://huggingface.co/blog/deep-rl-ppo
[5] https://colossalai.org/docs/features/zero_with_chunk/
[6] https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat