80 行 PyTorch 从零写 DeepSeek 的 MLA:量一遍 KV cache、踩一遍 absorption,你才会明白 vLLM 为什么要加专用内核
80 行 PyTorch 从零写 DeepSeek 的 MLA量一遍 KV cache、踩一遍 absorption你才会明白 vLLM 为什么要加专用内核我把 DeepSeek V2/V3 的 Multi-head Latent Attention下称 MLA按论文流程在单卡 RTX 3090 上用 80 行 PyTorch 写了出来然后做了三件事第一对比 cache 体积MLA 比同规模 MHA 小 57 倍第二验证 cache 正确性prefilldecode 的输出和一次性 forward 数值对齐到 1e-8第三把这份朴素实现拿去和 MHA 比 latency结果 16k 上下文时它比 MHA 慢了两倍。慢不是 bug慢是在告诉你MLA 从论文到生产引擎之间还隔着一个叫absorption的线性代数技巧而这个技巧之所以能做又是因为 DeepSeek 多写了一条decoupled RoPE分支。这篇文章用实验把这条链路走完先写出能 work 的最小 MLA再量一遍为什么它慢再补一版 absorbed decode把它压回比 MHA 还省的区间。如果你看完 vLLM / SGLang 的 MLA kernel 觉得它“比论文复杂很多”这篇文章是帮你把论文读到 kernel 的那一步。测试平台RTX 3090 (24 GB)PyTorch 2.9.1 CUDA 12.8fp16 forward。所有代码和原始日志在文末给出欢迎把相同脚本在 4090、A100、H100 上跑一遍数据会变但结论的方向不会变。1. MLA 到底省在哪里一张每 token 字节数表讨论 MLA 之前先把对手摆清楚。对于一个 decoder-only 模型KV cache 每层、每 token 的字节数只和 K/V 的形状 dtype 有关注意力形态每层、每 token 字节fp16MHAn_h × d_h K 同规模 V2 × n_h × d_h × 2GQAn_kv × d_h K 同规模 V2 × n_kv × d_h × 2MLA只缓存 c_kv 和 k_R所有 head 共享(d_c d_r) × 2代入 DeepSeek-V2 论文给出的n_h128, d_h128, d_c512, d_r64MLA 每层每 token 只要(51264)×2 1152 字节MHA 等规模要2×128×128×2 65536 字节——57 倍差距。把 60 层、32k 上下文放进公式config MHA/tok GQA8/tok MLA/tok 60L × 32k 总量 DeepSeek-V2 style 64.00 KB 4.00 KB 1.12 KB MHA120 GB GQA7.5 GB MLA2.11 GB Llama-3-70B-like 32.00 KB 4.00 KB 1.12 KB MHA60 GB GQA7.5 GB MLA2.11 GB3090 上实测分配验证了这组数字(d_cd_r)两段直接用 fp16 tensor 开出来L 2048 MHA/layer128.00 MB MLA/layer 2.25 MB ratio56.9x L 8192 MHA/layer512.00 MB MLA/layer 9.00 MB ratio56.9x L32768 MHA/layer 2.00 GB MLA/layer36.00 MB ratio56.9x L65536 MHA/layer 4.00 GB MLA/layer72.00 MB ratio56.9x作者判断 1如果你要横向比较多个注意力方案的“上下文成本”按每层每 token 的字节数算比按论文声明的压缩倍数靠谱得多。同样是“压缩”GQA 把 K/V 共享到 n_kv 个分组MLA 是把 K/V 压进一条共享 latent在 Llama-3-70B 这种已经 GQA-8 的基线上MLA 的相对优势从“57 倍”缩到“3.6 倍”并不是任何模型换 MLA 都能看到论文里的那种差距。2. 用 80 行 PyTorch 把 MLA 写清楚MLA 论文里最容易被忽略的是它的 Q 也走了低秩分解且把 KV 拆成content走低秩 latent和rope走独立 rotary 分支两段。只有理解这个拆法后面 absorption 和 decoupled RoPE 才讲得通。先给结构h ─┐ ┌─► c_q ──► q_c (per head, d_h) ├─ W_DQ/W_DKV ─── ┤ │ ├─► c_q ──► q_r (per head, d_r) ┐ │ ├─► RoPE │ │ ├─ W_KR ────────────────────────► k_r (shared, d_r) ┘ │ └─ W_DKV ─────────► c_kv (shared latent, d_c) │ ├─► W_UK ──► k_c (per head, d_h) └─► W_UV ──► v (per head, d_h) cache 只保留两条线**共享的 c_kv ∈ [B, L, d_c]** 和 **共享的 k_r ∈ [B, L, d_r]**。其余的 Q、K_C、V 都是运行时算出来的。 下面是我这版最小 MLA 的核心代码完整版见附录。变量名尽量沿用论文d_c 是 KV latent 维度d_r 是 rotary 分支维度d_cq 是 Q latent 维度。 python class MLA(nn.Module): def __init__(self, d_model1024, n_h16, d_h64, d_c128, d_r32, d_cq256): super().__init__() self.n_h, self.d_h, self.d_c, self.d_r, self.d_cq n_h, d_h, d_c, d_r, d_cq self.Wdq nn.Linear(d_model, d_cq, biasFalse) self.Wuq nn.Linear(d_cq, n_h * d_h, biasFalse) # content Q self.Wqr nn.Linear(d_cq, n_h * d_r, biasFalse) # rope Q self.Wdkv nn.Linear(d_model, d_c, biasFalse) self.Wuk nn.Linear(d_c, n_h * d_h, biasFalse) # content K self.Wuv nn.Linear(d_c, n_h * d_h, biasFalse) # V self.Wkr nn.Linear(d_model, d_r, biasFalse) # shared K_R self.Wo nn.Linear(n_h * d_h, d_model, biasFalse) def forward(self, h, pastNone): B, L, _ h.shape n_h, d_h, d_c, d_r self.n_h, self.d_h, self.d_c, self.d_r c_kv self.Wdkv(h) # [B, L, d_c] ← cache k_r self.Wkr(h) # [B, L, d_r] ← cache if past is not None: c_kv torch.cat([past[0], c_kv], dim1) k_r torch.cat([past[1], k_r ], dim1) Lt c_kv.shape[1] qz self.Wdq(h) q_c self.Wuq(qz).view(B, L, n_h, d_h) q_r self.Wqr(qz).view(B, L, n_h, d_r) k_c self.Wuk(c_kv).view(B, Lt, n_h, d_h) # 运行时重建 v self.Wuv(c_kv).view(B, Lt, n_h, d_h) # 运行时重建 k_r_h k_r.unsqueeze(2).expand(B, Lt, n_h, d_r) # 所有 head 共享 # decoupled RoPE: 只作用在 d_r 段上 pos_q torch.arange(Lt - L, Lt, deviceh.device) pos_k torch.arange(Lt, deviceh.device) cos_q, sin_q build_rope(pos_q, d_r, deviceh.device, dtypeh.dtype) cos_k, sin_k build_rope(pos_k, d_r, deviceh.device, dtypeh.dtype) q_r apply_rope(q_r, cos_q[None, :, None, :], sin_q[None, :, None, :]) k_r_h apply_rope(k_r_h, cos_k[None, :, None, :], sin_k[None, :, None, :]) q torch.cat([q_c, q_r], dim-1) # [B, L, H, d_hd_r] k torch.cat([k_c, k_r_h], dim-1) # [B, Lt, H, d_hd_r] q q.transpose(1, 2); k k.transpose(1, 2); vv v.transpose(1, 2) o F.scaled_dot_product_attention(q, k, vv, is_causal(past is None)) return self.Wo(o.transpose(1, 2).contiguous().view(B, L, n_h * d_h)), (c_kv, k_r) 正确性先过一遍。把同一段序列分两次输入前面 L-1 做 prefill最后一个 token 带着 cache 进来和整段一次性 forward 对比 cache equivalence test prefill segment max diff: 0.00e00decoded last token max diff: 4.47e-08数值基本在 fp32 的 round-off 噪声上cache 合并逻辑 OK。到这里 MLA 已经能 work 了。 **作者判断 2**很多人复现 MLA 时的第一个坑是把 k_r 和 q_r 也放到 per-head 里去然后忘了 k_r 其实是“所有 head 共享”的——这会让你的 cache 里多存 n_h × d_r 倍的冗余也会让 absorption 没法做。上面 k_r self.Wkr(h)维度 [B, L, d_r]与 n_h 解耦是 DeepSeek 原始设计的关键。 --- ## 3. 先给一个反直觉结果我的 80 行 MLA 比 MHA 还慢 把同一套超参d_model2048, n_h16, d_h128下的 MHA 和上面这版朴素 MLAd_c384, d_r32, d_cq1024塞进 fp16、单层、单卡的 prefilldecode 测试 prefill1024, decode256 MHA prefill 1.65 ms decode/tok0.449 ms cache10.00 MBMLA prefill 2.77 ms decode/tok1.144 ms cache 1.02 MB prefill4096, decode256 MHA prefill 3.86 ms decode/tok0.245 ms cache34.00 MBMLA prefill 4.91 ms decode/tok1.057 ms cache 3.45 MB prefill16384, decode256 MHA prefill 23.22 ms decode/tok0.584 ms cache130.00 MBMLA prefill 46.74 ms decode/tok2.211 ms cache 13.20 MB结论直白**cache 是省了latency 是赔的**。16k 上下文时 MLA 每个 decode token 要 2.2 ms几乎是 MHA 的 4 倍。 为什么看 decode 步骤里 MLA 多干的活 - self.Wuk(c_kv)把整条 c_kv ∈ [B, Lt, d_c] 升维到 [B, Lt, n_h·d_h]这是一个 Lt × d_c × (n_h·d_h) 的 matmul随 Lt 线性增长 - - self.Wuv(c_kv)同上再做一次 - - 之后再把 K/V 展平成 [B, H, Lt, d_h] 做 attention。 也就是说**每产生一个 token你都在把整条历史 cache 做两次 up-projection**。MHA 没这个负担K 和 V 本来就是 [B, Lt, H, d_h]append 一下就能用。 所以 DeepSeek V2 的 paper 并没有骗你。它真实的意思是“cache 能压 57 倍但是朴素实现的 decode kernel 会慢所以我们论文里顺手给了一个 absorption 技巧来把这两次升维吸进 Q 端的常数矩阵”。 **作者判断 3**如果你在博客或知乎看到“把 MLA 换进 MHA 代码里就能加速长上下文”的说法应该警惕。没有 absorption 的 MLA 在 decode 阶段是被带宽 计算两头打的它只在 cache 极端紧张比如 3090/4090 跑 65k 上下文MHA 直接 OOM 的场景下才一定更优。工程团队要么接 vLLM/SGLang 的 MLA kernel要么自己把 absorption 写对才能吃到它的红利。 --- ## 4. 为什么是“decoupled RoPE”一行 einsum 就能看清 论文里 K concat(K_C, K_R) 的拆法第一眼看上去像工程妥协。真正的原因是MLA 的压缩技巧依赖 W_UK 和 Q 之间做矩阵合并absorption而 RoPE 不允许这种合并。 先写两条路径 python # Path A: 朴素做法——先从 c_kv 重建每 head 的 K_C再和 Q_C 算分数 K_C torch.einsum(bld,hdk-blhk, c_kv, W_UK) # [B,L,H,Dh] scores_A torch.einsum(blhk,bmhk-bhlm, Q_C, K_C) # Path B: absorption——把 W_UK 吸到 Q 端直接和 c_kv 打分 Q_abs torch.einsum(blhk,hdk-blhd, Q_C, W_UK) # [B,L,H,Dc] scores_B torch.einsum(blhd,bmd-bhlm, Q_abs, c_kv)两条路径在线性代数上是同一件事实测也是content-path absorption max diff: 2.38e-06误差就是 fp32 的舍入噪声。这意味着decode 步骤可以跳过W_UK矩阵、直接让 Q_abs 去和 cache 打分维度从[B,L,H,Dh] · [B,Lt,H,Dh]变成[B,L,H,Dc] · [B,Lt,Dc]后者只需一份共享 c_kv 而不是 per-head K——这正是 vLLM / SGLang MLA kernel 的真实形态。然后我们把 RoPE 加回 K_C 再看K_C_ropeK_C*rot# 简化的 RoPE 位置因子scores_A_ropetorch.einsum(blhk,bmhk-bhlm,Q_C,K_C_rope)scores_B_rope_brokentorch.einsum(blhd,bmd-bhlm,Q_abs,c_kv)# 还是未旋转的 c_kv此时if we RoPE K_C, absorption path diverges by: 1.03e01误差从1e-6跳到10.3。原因RoPE 的位置因子rot(pos)是和W_UK之后的那个 Dh 维绑在一起的W_UK · rot(pos)不等于rot(pos) · W_UKabsorption 步骤把W_UK预吸到 Q_abs 里之后就再也对不上了。这就是为什么 MLA 必须把 K 拆成两段content 分支是纯线性映射允许做 absorptionrope 分支用很小的 d_rDeepSeek-V2 是 64单独挂在一边大家都 RoPE 一下再拼起来绕开矩阵合并冲突。作者判断 4“decoupled RoPE” 的真正含义不是“多一个额外的 RoPE”而是“用很小的 d_r 换来 W_UK 可以被吸收”。这是 MLA 从“cache 更小”进化到“decode 也不慢”的那把钥匙很多教程在讲 MLA 时把它讲成工程手感其实是数学硬约束。5. 把 absorption 补上decode/tok 从 2.2 ms 回到 1.4 ms光证明 absorption 在数学上等价没意义我把它写成一个只在 decode 时生效的版本MLAAbsorbDecode完整代码见附录。核心差异decode 时不再调用Wuk(c_kv)/Wuv(c_kv)而是把W_UK重排成[H, d_c, d_h]和 Q 做einsum(blhd,hcd-blhc)得到 Q_abscontent 分数 Q_abs · c_kv^Trope 分数 q_r · k_r^T共享 k_r 广播到各 headattention 先在 latent 维得到ctx_latent [B, H, 1, d_c]再用W_UV投到[B, H, 1, d_h]。在相同模型、相同 prefill 后续 256 token decode 的条件下prefill 1024 base-MLA/tok1.071 ms absorbed/tok1.384 ms diff1.53e-05 speedup0.77x prefill 4096 base-MLA/tok1.085 ms absorbed/tok1.343 ms diff7.63e-06 speedup0.81x prefill 16384 base-MLA/tok2.200 ms absorbed/tok1.364 ms diff3.81e-06 speedup1.61x三点值得看correctnessfp16 下和朴素版本最大差 1.5e-5和 fp16 精度同量级。短上下文 absorption 反而更慢0.77x因为 absorbed 版本每个 decode 都要多做一次W_UK重排和einsum在 Lt 小时那点常数开销盖过了收益。长上下文 absorption 变成 1.61x16k 时朴素版本被Wuk(c_kv)拖成 2.2 ms/tokabsorbed 版稳定在 1.36 ms/tok——这就是生产 kernel 要把整件事搬到 CUDA 层的原因Python einsum 都能看出趋势手写 kernel 把常数摁下去之后这条曲线会更陡。作者判断 5在 Python PyTorch 层写 MLA最多只能看到“absorption 在长上下文变快”这种趋势真正的低延迟必须靠 CUDA kernel例如 SGLang 的flash_mla、vLLM 的mla_fwd把 Q_abs、c_kv、k_r 的 gather attention 融进一个 pass。如果你的推理场景是中短上下文4k、batch 很小、模型层数不多MHA/GQA 路径反而更快换 MLA 不一定划算。6. 如果你要把 MLA 放进真实项目先决定这几件事按使用路径分三类一、你直接用 DeepSeek-V2 / V3 / R1 官方 checkpoint。如果是 vLLM 推理升到支持 MLA 专用 kernel 的版本MLA fast path 在 2024 年底开始合进主线2025 年以后已是默认。跑起来就能享受 57x cache 压缩和长上下文 throughput。如果是 Hugging Face Transformers 纯 eager forward你拿到的是我第 2 节那版的性能画风——cache 省了但 decode 慢长序列才打平 MHA。别以此为依据评估 MLA。如果你写单卡 demo 脚本3090/4090请记得d_c512意味着每层每 token 1.1 KB32k 上下文 60 层不过 2 GB这是 MLA 在消费卡跑长上下文 LLM 的核心价值。二、你要把 MLA 塞进自己的非 DeepSeek 架构。给出d_c通常 d_model/2 ~ d_model/4 之间和d_r64 是个稳妥值别调到 8 这种极端值。训练时 Q/K/V 各一条 downup 投影都得带上只做 KV 投影而不做 Q 投影参数量省不了多少但会把收敛搞复杂。预训练初期 MLA 模型收敛速度和 MHA 接近真正的差异在长上下文 eval 时才出现不要拿 1k 上下文的 perplexity 判断它优劣。三、你只想自己推理时加载 MLA 模型。不要尝试把 MLA 模型手工“反”成 MHA 再跑因为W_UK/W_UV拼起来的等效 K/V 维度是n_h × d_h 16384单个 K 矩阵就 256 MB反过来算既慢又占显存。真的想 debug 精度就关掉 absorption用第 2 节那个朴素版跑一遍它和 absorption 版在 fp32 上对得上1e-6、fp16 对得上1e-5是你的 ground truth。7. 参考与延伸阅读DeepSeek-AI,DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model, arXiv:2405.04434MLA 首次系统提出含完整超参与对比实验。DeepSeek-AI,DeepSeek-V3 Technical Report, arXiv:2412.19437MLA 在更大模型上的配置与训练 reciped_c512, d_r64的来源。vLLM 源码vllm/attention/backends/mla.py中的apply_forward与_flash_mla_fwd对应本文第 5 节 absorption decoupled RoPE 的 CUDA 版本。SGLang 源码sglang/srt/layers/attention/flashinfer_mla_backend.py可以对比 FlashInfer 的 MLA kernel 与本文 einsum 版在结构上的异同。本文实验脚本mla_min.py/run_bench.py/absorb_demo.py/absorbed_decode.py复现结果只需要 PyTorch ≥ 2.1 任意 10 GB 显存 GPU所有命令和日志见末尾附录。附录完整实验脚本A.mla_min.py本文第 2 节importmath,torchimporttorch.nnasnnimporttorch.nn.functionalasFdefrotate_half(x):x1,x2x[...,::2],x[...,1::2]returntorch.stack((-x2,x1),dim-1).flatten(-2)defapply_rope(x,cos,sin):returnx*cosrotate_half(x)*sindefbuild_rope(pos,dim,base10000.0,deviceNone,dtypetorch.float32):inv1.0/(base**(torch.arange(0,dim,2,devicedevice,dtypedtype)/dim))freqstorch.outer(pos.to(inv.dtype),inv)embtorch.repeat_interleave(freqs,2,dim-1)returnemb.cos(),emb.sin()classMLA(nn.Module):def__init__(self,d_model1024,n_h16,d_h64,d_c128,d_r32,d_cq256):super().__init__()self.n_h,self.d_h,self.d_c,self.d_r,self.d_cqn_h,d_h,d_c,d_r,d_cq self.Wdqnn.Linear(d_model,d_cq,biasFalse)self.Wuqnn.Linear(d_cq,n_h*d_h,biasFalse)self.Wqrnn.Linear(d_cq,n_h*d_r,biasFalse)self.Wdkvnn.Linear(d_model,d_c,biasFalse)self.Wuknn.Linear(d_c,n_h*d_h,biasFalse)self.Wuvnn.Linear(d_c,n_h*d_h,biasFalse)self.Wkrnn.Linear(d_model,d_r,biasFalse)self.Wonn.Linear(n_h*d_h,d_model,biasFalse)defforward(self,h,pastNone):B,L,_h.shape n_h,d_h,d_c,d_rself.n_h,self.d_h,self.d_c,self.d_r c_kvself.Wdkv(h)k_rself.Wkr(h)ifpastisnotNone:c_kvtorch.cat([past[0],c_kv],dim1)k_rtorch.cat([past[1],k_r],dim1)Ltc_kv.shape[1]qzself.Wdq(h)q_cself.Wuq(qz).view(B,L,n_h,d_h)q_rself.Wqr(qz).view(B,L,n_h,d_r)k_cself.Wuk(c_kv).view(B,Lt,n_h,d_h)vself.Wuv(c_kv).view(B,Lt,n_h,d_h)k_r_hk_r.unsqueeze(2).expand(B,Lt,n_h,d_r)pos_qtorch.arange(Lt-L,Lt,deviceh.device)pos_ktorch.arange(Lt,deviceh.device)cos_q,sin_qbuild_rope(pos_q,d_r,deviceh.device,dtypeh.dtype)cos_k,sin_kbuild_rope(pos_k,d_r,deviceh.device,dtypeh.dtype)q_rapply_rope(q_r,cos_q[None,:,None,:],sin_q[None,:,None,:])k_r_happly_rope(k_r_h,cos_k[None,:,None,:],sin_k[None,:,None,:])qtorch.cat([q_c,q_r],dim-1)ktorch.cat([k_c,k_r_h],dim-1)qq.transpose(1,2);kk.transpose(1,2);vvv.transpose(1,2)oF.scaled_dot_product_attention(q,k,vv,is_causal(pastisNone))returnself.Wo(o.transpose(1,2).contiguous().view(B,L,n_h*d_h)),(c_kv,k_r)### B. KV 字节表和 3090 分配本文第 1 节实验日志 cache equivalence test prefill segment max diff: 0.00e00decoded last token max diff: 4.47e-08 per-token KV bytes config MHA/tok GQA8/tok MLA/tok (60L × 32k)DeepSeek-V2 style 64.00 KB 4.00 KB 1.12 KB MHA120.00 GB GQA7.50 GB MLA2.11 GBLlama-3-70B-like 32.00 KB 4.00 KB 1.12 KB MHA 60.00 GB GQA7.50 GB MLA2.11 GB live 3090 allocation L 2048 MHA/layer 128.00 MB MLA/layer 2.25 MB ratio56.9xL 8192 MHA/layer 512.00 MB MLA/layer 9.00 MB ratio56.9xL 32768 MHA/layer 2.00 GB MLA/layer 36.00 MB ratio56.9xL 65536 MHA/layer 4.00 GB MLA/layer 72.00 MB ratio56.9x### C. absorption 和 decoupled RoPE 的等价/冲突检查第 4 节content-path absorption max diff: 2.38e-06if we RoPE K_C, absorption path diverges by: 1.03e01 (this is the reason decoupled RoPE is mandatory)### D. absorbed decode 与朴素版本 latency 对比第 5 节prefill 1024 base-MLA/tok1.071 ms absorbed/tok1.384 ms correctness max-diff1.53e-05 speedup0.77xprefill 4096 base-MLA/tok1.085 ms absorbed/tok1.343 ms correctness max-diff7.63e-06 speedup0.81xprefill 16384 base-MLA/tok2.200 ms absorbed/tok1.364 ms correctness max-diff3.81e-06 speedup1.61x 测试平台RTX 3090 (24 GB) Driver 590.44.01 CUDA 13.1PyTorch 2.9.1cu128fp16 forwardbatch1d_model2048, n_h16, d_h128, d_c384, d_r32, d_cq1024。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2577807.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!