CANN-昇腾NPU长序列训练-128K上下文怎么不OOM
Llama 3 支持 128K 上下文长度。训练时 128K 序列的 Attention 显存是 O(N²)128K × 128K × fp16 32GB 每层32 层 1TB。显然放不下。FlashAttention 把显存从 O(N²) 降到 O(N)但在训练场景下还有额外挑战。FlashAttention 的显存节省标准 Attention Q·K^T [batch, heads, seq, seq] ← 这个矩阵是 O(N²) 128K × 128K × fp16 × 32 heads 32GB/层 FlashAttention 不存 Q·K^T 矩阵分块计算只存 O(N) 的归一化因子 显存 ≈ Q K V O 4 × batch × heads × seq × dim × 2 bytes 128K × 128 × fp16 × 32 heads 1GB/层从 32GB 降到 1GB32 层从 1TB 降到 32GB。这就是 FlashAttention 训练长序列的前提。训练的额外显存激活推理只存 KV Cache。训练要存所有中间激活给 backward 用每层需要存的激活 Q, K, V: 3 × batch × heads × seq × dim Q·K^T 归一化因子: batch × heads × seqFlashAttention 的 O(N) 存储 Attention 输出: batch × heads × seq × dim FFN 中间结果: batch × seq × ff_dim Llama2-7B, seq128K, batch1: Q/K/V: 3 × 32 × 128K × 128 × 2 3GB 归一化因子: 32 × 128K × 4 16MB FFN 中间: 128K × 14336 × 2 3.5GB 每层约 7GB32 层 224GB224GB 的激活显存8 卡 Atlas 800I A2 总共 512GB去掉权重和优化器状态约 80GB剩 432GB 给激活——刚好放得下但没有余量。激活重计算Activation Recomputation用时间换空间forward 不存中间激活backward 需要时重新算一遍。fromtorch_npu.npuimportamp# 完整激活保存快但显存多withamp.autocast(dtypetorch.bfloat16):lossmodel(x)# 选择性激活重计算只重算 Attention 部分O(N²) 的那部分model.gradient_checkpointing_enable(gradient_checkpointing_kwargs{use_reentrant:False})withamp.autocast(dtypetorch.bfloat16):lossmodel(x)选择性重计算的显存节省策略激活显存训练速度全部保存224GB100%选择性重计算80GB85%全部重计算40GB70%选择性重计算只重算 AttentionFlashAttention 的 forward 很快保留 FFN 的中间结果重算代价大。这是 128K 训练的标配。Sequence ParallelismTensor Parallel 只切 head 维度Attention 的 LayerNorm 和残差连接在每个 rank 上重复计算。Sequence Parallelism 把这些操作沿序列维度切分TP: LayerNorm(x) → 每个 rank 算完整的 LayerNorm SP: LayerNorm(x) → 每个 rank 只算 seq/N 的一段 通信 TP每层 2 次 All-Reduce SP每层 2 次 All-Gather Reduce-Scatter通信量相同但显存省 N 倍SP 的 LayerNorm 激活显存从batch × seq × hidden降到batch × seq/N × hidden。8 卡 SP 的 LayerNorm 激存减到 1/8。实际配置Llama2-7B, 128K 序列, 8 卡 Atlas 800I A2fromatbimportTrainingConfig configTrainingConfig(modelmeta-llama/Llama-2-7b-hf,devicesnpu:0,1,2,3,4,5,6,7,tensor_parallel_size4,sequence_parallelTrue,gradient_checkpointingselective,# 选择性重计算micro_batch_size1,accumulation_steps16,max_seq_len131072,)显存分配权重 优化器: 80GB (4卡TP) 激活: 80GB (选择性重计算 SP) KV Cache: 32GB 余量: 320GB320GB 的余量意味着 batch 还能开更大或者序列更长。128K 长序列训练的三板斧FlashAttention 省显存、选择性激活重计算换空间、Sequence Parallel 切序列维度。三个一起上8 卡就能训 128K。仓库在这里https://atomgit.com/cann/ops-transformer
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2638846.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!