FlashAttention的OOM排查:为什么显存够了还是报内存不足?
之前有个团队在昇腾NPU上跑Llama-2-7B模型是FP16权重seq_len4096。他们算了算显存模型权重13.5GB 激活值4GB KV Cache 4GB 21.5GB昇腾910有32GB显存绰绰有余。结果一跑就报OOMOut Of Memory。他们懵了明明只用了21.5GB32GB显存怎么就满了我帮他们排查了一下发现问题出在FlashAttention的内存分配策略——它不是一次性分配所有显存而是分块分配。每个分块在SRAM里处理完之后要暂时存一个中间结果这个中间结果会占用额外的显存。今天把这个内存分配的问题讲清楚帮你排查OOM。先打个比方厨房的台面空间想象你在厨房做饭台面空间SRAM有限。你要做一道复杂的菜需要很多步骤。标准Attention的做法先把所有食材摆在台面上一次性分配所有显存一次做完最后清理台面。FlashAttention的做法每次只拿一小块食材到台面上分块分配做完这一步放到一边临时存储再拿下一块。这样台面不用一直占满。问题在哪FlashAttention的分块策略虽然省了总显存但峰值显存可能更高——因为每个分块处理完的临时结果还在显存里要等所有分块处理完才能释放。FlashAttention的内存分配到底怎么回事FlashAttention在昇腾NPU上的内存分配分三层第一层模型权重静态分配模型权重是最稳定的显存占用在整个推理/训练过程中都存在。Llama-2-7B FP16的权重分布 QKV投影4096 × (3 × 4096) × 2 302 MB 输出投影4096 × 4096 × 2 32 MB FFN门控4 × 4096 × 11008 × 2 361 MB FFN降维11008 × 4096 × 2 90 MB 词表投影32000 × 4096 × 2 262 MB 位置编码513 × 4096 × 2 4 MB LayerNorm4096 × 2 × 2 × 2 0.06 MB 总计~1 GB32层堆叠 × 1 GB ~13.5 GB跟之前算的一致。第二层KV Cache动态分配KV Cache是推理时动态分配的。每个token生成完它的K和V要存下来供后续token的Attention用。KV Cache的计算 每个token的KV大小 num_kv_heads × head_dim × 2KV 32 × 128 × 2 8 KB seq_len4096的KV Cache 4096 × 8 KB 32 MB单层 32层的KV Cache 32 × 32 MB **1 GB**第三层FlashAttention的中间结果最容易OOM的地方FlashAttention在分块处理的时候会产生一些中间结果。这些中间结果不会立即释放要等整个Attention层处理完才能释放。FlashAttention的中间结果每层 Q分块缓冲block_size × head_dim × 2 × num_heads 128 × 128 × 2 × 32 1 MB K分块缓冲block_size × head_dim × 2 × num_kv_heads 128 × 128 × 2 × 32 1 MB V分块缓冲同上 1 MB 输出缓冲block_size × head_dim × 2 × num_heads 1 MB 在线Softmax状态m和lblock_size × 4 × 2 1 KB可忽略 每层的中间结果总计~4 MB 32层的中间结果32 × 4 **128 MB**等等128MB看起来不多啊为什么会OOM真正的问题Gradient Checkpointing的中间结果问题出在Gradient Checkpointing激活重计算——训练的时候为了省显存会重计算前向传播的中间结果。但FlashAttention的中间结果不在这个策略里# Gradient Checkpointing的配置modelGradientCheckpointing(model,checkpoint_ratio0.5)# FlashAttention的中间结果不在checkpoint策略里# 这些中间结果会一直占用显存不会被释放FlashAttention的中间结果不在Gradient Checkpointing的保护范围内所以会占用额外的显存。OOM的常见原因原因1batch_size太大batch_size太大会导致KV Cache显存占用暴涨。batch_size1KV Cache 1 GB batch_size4KV Cache 4 GB线性增长 batch_size16KV Cache 16 GB batch_size32KV Cache 32 GB爆炸原因2seq_len太长seq_len太长会导致KV Cache和Attention矩阵的中间结果都变大。seq_len2048KV Cache 512 MB中间结果 64 MB seq_len4096KV Cache 1 GB中间结果 128 MB seq_len8192KV Cache 2 GB中间结果 256 MB seq_len16384KV Cache 4 GB中间结果 512 MB原因3没有开PagedAttentionPagedAttention能把KV Cache的显存利用率从34%提升到91%。不开PagedAttention同样的显存能跑的batch_size更小。原因4混合精度配置不对如果开了FP16训练但BF16推理或者反过来昇腾NPU要做额外的精度转换占用额外显存。OOM排查清单你的FlashAttention报OOM按这个清单查defdiagnose_oom():FlashAttention OOM诊断# 1. 检查模型权重大小total_paramssum(p.numel()forpinmodel.parameters())weight_memtotal_params*2/(1024**3)# FP16print(f模型权重显存{weight_mem:.2f}GB)# 2. 检查KV Cache大小kv_mem_per_layerseq_len*num_kv_heads*head_dim*2*2/(1024**2)# FP16, MBtotal_kv_memkv_mem_per_layer*num_layersprint(fKV Cache显存单层{kv_mem_per_layer:.2f}MB)print(fKV Cache显存总{total_kv_mem:.2f}GB)# 3. 检查中间结果大小intermediate_mem_per_layer(block_size*head_dim*2*num_heads*4/(1024**2)# QKV输出缓冲)total_intermediateintermediate_mem_per_layer*num_layersprint(f中间结果显存单层{intermediate_mem_per_layer:.2f}MB)print(f中间结果显存总{total_intermediate:.2f}GB)# 4. 计算总显存total_memweight_memtotal_kv_memtotal_intermediateprint(f\n估算总显存{total_mem:.2f}GB)print(f可用显存{torch.npu.get_device_properties(0).total_memory/(1024**3):.2f}GB)# 5. 判断iftotal_memtorch.npu.get_device_properties(0).total_memory/(1024**3):print(\n❌ 显存不足)print(建议)print( - 减小batch_size)print( - 减小seq_len)print( - 开PagedAttention)print( - 用INT8 KV Cache量化)else:print(f\n✅ 显存估算足够剩余{torch.npu.get_device_properties(0).total_memory/(1024**3)-total_mem:.2f}GB)print(实际OOM可能是其他原因内存碎片、驱动问题等)# 运行诊断diagnose_oom()解决OOM的方法方法1开PagedAttentionPagedAttention能把KV Cache的显存利用率从34%提升到91%。# vLLM配置python-m vllm.entrypoints.openai.api_server \--model./models/Llama-2-7b-chat-hf \--enable-flash-attn \--use-paged-attention \# 开PagedAttention--max-num-seqs32方法2用INT8 KV Cache量化INT8量化能把KV Cache的显存减半。# vLLM配置python-m vllm.entrypoints.openai.api_server \--model./models/Llama-2-7b-chat-hf \--enable-flash-attn \--kv-cache-dtype int8 \# INT8 KV Cache--max-num-seqs32方法3减小batch_sizebatch_size减半KV Cache显存减半。方法4用Gradient Checkpointing训练训练时开Gradient Checkpointing省前向传播的激活值显存。modelGradientCheckpointing(model,checkpoint_ratio0.5)总结一下FlashAttention的OOM原因batch_size太大→ KV Cache显存爆炸seq_len太长→ KV Cache和中间结果都变大没开PagedAttention→ KV Cache显存利用率低混合精度配置不对→ 额外的精度转换占用显存估算公式总显存 模型权重 batch_size × KV_Cache_per_token × seq_len 中间结果解决OOM的方法开PagedAttention用INT8 KV Cache量化减小batch_size开Gradient Checkpointing训练代码和文档https://atomgit.com/cann/ops-transformer
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2639132.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!