目录
一.引言
二.常用参数
◆ ModelArguments
◆ DataArguments
◆ TrainingArguments
◆ GeneratingArguments
三.代码实现
◆ Python 代码
◆ Shell 代码
四.总结
一.引言
LLM 相关训练框架都会引入 ModelArguments、DataArguments、TrainingArguments、GeneratingArguments 并通过 Transformer.HfArgumentParser 进行整合,实现了两行代码处理训练全程的参数问题。
DataArguments - 数据集参数
TrainingArguments - 训练参数
GeneratingArguments - 生成参数
二.常用参数
◆ ModelArguments
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")
ModelArguments 主要存储模型加载与配置的相关参数,一般还有以下参数,大家可以自定义:
| 参数名称 | 默认 | 类型 | 含义 |
| model_name_or_path | None | str | 模型地址或名称 |
| cache_dir | None | str | 缓存地址 |
| use_fast_tokenizer | False | bool | 使用快速 tokenizer |
| padding_side | left | str | 模型 pad 选择 |
| quantization_bit | None | int | 量化 bit 选择 |
| compute_type | None | torch.dtype | 模型参数类型 |
| checkpoint_dir | None | str | 微调参数地址 |
| mode | None | str | reward、lora |
| plot_loss | False | bool | 打印训练 Loss |
◆ DataArguments
@dataclass
class DataArguments:
data_path: str = field(
default=None, metadata={"help": "Path to the training data."}
)
DataArguments 主要负责数据集相关参数,数据集通过 dataset 构成,通常包含下述参数:
| 参数名称 | 默认 | 类型 | 含义 |
| data_path | None | str | 数据集地址 |
| process_num | None | int | 并行处理 |
| max_source_length | 512 | int | source 最大长度 |
| max_target_length | 512 | int | target 最大长度 |
| max_samples | None | int | 最大样本数 |
| ignore_pad_token | None | int | loss 计算是否忽略 |
| prompt_template | None | str | 样本生成 prompt 模板 |
◆ TrainingArguments
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=512,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
use_lora: bool = field(default=False)
output_dir: str = field(default="")
TrainingArguments 主要存储模型微调,训练相关的参数:
| 参数名称 | 默认 | 类型 | 含义 |
| finetuning_type | lora | str | 微调类型 |
| lora_target | q_proj,v_proj | str | 微调 Layer |
| lora_rank | 8 | int | lora 降维维度 |
| lora_alpha | 32.0 | float | lora 微调比例因子 |
| lora_dropout | 0.1 | float | dropout 比例 |
| num_hidden_layers | 32 | int | Decode 数量 |
| num_layer_trainable | 3 | int | freeze layer 数量 |
| name_module_trainable | mlp | str | freeze 训练层选择 |
| output_dir | None | str | 模型输出地址 |
◆ GeneratingArguments
@dataclass
class GeneratingArguments:
do_sample: Optional[bool] = field(
default=True,
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
)
GeneratingArguments 主要负责 model generate 生成的配置:
| 参数名称 | 默认 | 类型 | 含义 |
| do_sample | True | bool | 采样或贪心 |
| temperature | 0.95 | float | 调整下一个 token 的概率 |
| top_p | 0.7 | float | token 概率 top 区间 |
| top_k | 50 | int | token 词库数量 |
| num_beams | 1 | int | beam search 数量 |
| max_length | None | int | 最大生成 token 数 |
| max_new_tokens | 512 | int | 最多新 toekn 生成数 |
| repatition_penalty | 1.0 | float | 重复惩罚 |
| length_penalty | 1.0 | float | 长度惩罚 |
之前单独整理了生成的参数和代码,可以参考: LLM - model batch generate 生成文本
三.代码实现
◆ Python 代码
from typing import Optional
from dataclasses import dataclass, field
import transformers
...
添加上述的 Argument Class
...
if __name__ == '__main__':
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, GeneratingArguments))
model_args, data_args, training_args, generate_args = parser.parse_args_into_dataclasses()
print(model_args)
print(data_args)
print(training_args)
print(generate_args)
两行搞定多类参数,参数对应属性使用 args.xxx 调用即可。
◆ Shell 代码
#!/bin/bash
python GetConfigByArgs.py \
--report_to "none" \
--data_path "data/belle_chat_ramdon_10k.json" \
--model_name_or_path "baichuan-inc/Baichuan2-7B-Base" \
--output_dir "output" \
--model_max_length 512 \
--num_train_epochs 4 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 1 \
--save_strategy epoch \
--learning_rate 2e-5 \
--lr_scheduler_type constant \
--adam_beta1 0.9 \
--adam_beta2 0.98 \
--adam_epsilon 1e-8 \
--max_grad_norm 1.0 \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--logging_steps 1 \
--gradient_checkpointing True \
--deepspeed ds_config.json \
--bf16 False \
--tf32 False
通过 -- 传递我们需要的参数即可。
四.总结
这个没啥总结的了,就是觉得写法比较优雅,后面自己的脚本也可以借用。




















