大语言模型训练中的显存占用与优化方法简述
在进行大语言模型LLM的微调或预训练时显存VRAM不足通常是首要面临的问题。为了在有限的硬件资源下完成训练了解显存的具体去向以及相应的优化技术是比较基础的工作。从模型训练的流程来看显存的占用主要可以分为两大部分模型状态Model States和剩余耗时产生的中间变量主要是激活值。以下对相关的优化方法做简单的梳理。一、 模型状态的显存占用与 ZeRO 技术模型状态包含了训练过程中最核心的三类数据模型参数Weights、梯度Gradients以及优化器状态Optimizer States。参数与梯度对于一个 1.7B17亿参数的模型如果使用 BF16 或 FP16 精度参数本身约占 3.4GB。在训练过程中系统还需要存储一份同样大小的梯度。优化器状态这是显存占用的“大户”。以常用的 Adam 优化器为例它需要为每个参数记录动量Momentum和方差Variance。如果使用全精度FP32来存储这些状态以保证精度其占用量通常是参数本身的数倍。为了解决这些静态数据的冗余问题微软提出的ZeROZero Redundancy Optimizer技术被广泛应用。它通过将数据切片并分散到多个显卡上来降低单卡负载ZeRO-1仅对优化器状态进行切片每张显卡只负责维护一部分参数的优化器状态。ZeRO-2在 1 的基础上进一步对梯度进行切片。这是目前平衡显存节省与通信效率较好的选择。ZeRO-3对模型参数也进行切片。当某一层需要计算时临时从其他显卡“借”来参数算完即释放。这种方式能最大程度节省显存但显卡间的通信开销会显著增加。二、 激活值的显存占用与重算机制激活值Activations是指模型在“前向传播”过程中每一层神经元计算出的中间结果。与模型参数不同激活值的占用量是动态的它与训练时的批大小Batch Size和序列长度Sequence Length成正比。在处理长文本时激活值的显存占用往往会超过模型参数本身。由于反向传播计算梯度时必须用到这些中间结果因此默认情况下它们必须保留在显存中。目前的优化主流方案是梯度检查点Gradient Checkpointing其逻辑较为简单逻辑在前向传播时不再保存所有层的激活值而是只保留一小部分关键节点的“检查点”。重算当反向传播需要用到被删除的中间值时系统会根据最近的一个检查点重新进行一次前向计算。代价这是一种典型的“以时间换空间”的方法。它能节省大量的显存有时可达 70% 以上但会增加约 33% 的计算时间。三、 激活值的卸载与并行策略除了重算还有一些进阶的手段来处理激活值虽然它们对硬件环境的要求更高激活值卸载Offloading将暂时不用的激活值通过 PCIe 总线搬运到 CPU 内存中需要时再搬回。受限于 PCIe 的带宽这种方法在某些配置下可能会产生较明显的延迟。序列并行Sequence Parallelism将长文本切分成几段分配给不同的显卡分别计算。这属于分布式训练的高级范畴通常需要较快的跨卡互联带宽支持。四、 参数高效微调LoRA的辅助作用在讨论上述底层优化时不得不提LoRALow-Rank Adaptation。严格来说LoRA 改变的是需要更新的参数量。因为它冻结了原始模型的大部分参数只训练极小规模的旁路矩阵这直接导致梯度大幅减少只需要存储少量可训练参数的梯度。优化器状态减少对应的优化器记录也随之减少。虽然 LoRA 不直接改变激活值的计算方式但由于它极大降低了“模型状态”的显存门槛使得我们有更多的空间去增加 Batch Size 或序列长度。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2453758.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!