GQA:多查少算的 Attention 头组合
本文基于昇腾CANN和昇腾NPU围绕 ops-transformer 仓库的相关技术展开。MHAMulti-Head Attention每个 Head 一套 QKV——8 个 Head 就是 8 组。MQA 省过头了——8 个 Head 共享 K、V。GQAGrouped Query Attention走在中间8 个 Head 分 4 组组内共享 K、V。CANN 的 ops-transformer 库用 Ascend C 把 GQA 做成融合算子避免了冗余的 K、V 搬运。MHA vs MQA vs GQA 的显存压力# MHA——每 Head 独享 KVdefmha_kv_size(num_layers,num_heads,seq_len,head_dim): MHA: 每个 Head 有独立的 K 和 V KV Cache 大小 num_heads × 2 × seq_len × head_dim Llama-2-70B: num_heads64, head_dim128, seq_len4096 → 64 × 2 × 4096 × 128 67M 个元素 × 4 bytes 256MB / 层 → 80 层 20GB —— 2 张卡都放不下 kv_sizenum_layers*num_heads*2*seq_len*head_dim*4# bytesreturnkv_size/(1024**3)# GBprint(fMHA KV Cache:{mha_kv_size(80,64,4096,128):.1f}GB)# 输出20.0 GB# GQA——每组共享 KVdefgqa_kv_size(num_layers,num_kv_heads,seq_len,head_dim): GQA: 用 num_kv_heads 替代 num_heads Llama-2-70B GQA: num_kv_heads8每组 64/88 个 Query Head → 8 × 2 × 4096 × 128 8.4M / 层 → 80 层 2.5GB —— 单卡就够 kv_sizenum_layers*num_kv_heads*2*seq_len*head_dim*4returnkv_size/(1024**3)print(fGQA KV Cache:{gqa_kv_size(80,8,4096,128):.1f}GB)# 输出2.5 GBMHA 要 20GB 存 KV Cache——80 层跑不了单卡。GQA 砍到 2.5GB余下的 77.5GB HBM 给模型权重。GQA 的计算过程# GQA 的 Attention 计算——组内 Query 共享一组 KVimporttorchimporttorch.nn.functionalasFclassGQAAttention(torch.nn.Module):def__init__(self,hidden_dim,num_heads,num_kv_heads):super().__init__()assertnum_heads%num_kv_heads0,Query Heads 数必须是 KV Heads 的整数倍self.num_headsnum_heads# 32self.num_kv_headsnum_kv_heads# 8self.head_dimhidden_dim//num_heads# 128self.groupsnum_heads//num_kv_heads# 4# Q 投影hidden_dim → num_heads × head_dimself.q_projtorch.nn.Linear(hidden_dim,num_heads*self.head_dim)# K、V 投影hidden_dim → num_kv_heads × head_dim比 MHA 小 4 倍self.k_projtorch.nn.Linear(hidden_dim,num_kv_heads*self.head_dim)self.v_projtorch.nn.Linear(hidden_dim,num_kv_heads*self.head_dim)defforward(self,x,past_kvNone):B,S,Hx.shape qself.q_proj(x).reshape(B,S,self.num_heads,self.head_dim)kself.k_proj(x).reshape(B,S,self.num_kv_heads,self.head_dim)vself.v_proj(x).reshape(B,S,self.num_kv_heads,self.head_dim)# 关键步骤把 KV 头广播到每组 Query Head# [B, S, 8, 128] → [B, S, 32, 128]kk.repeat_interleave(self.groups,dim2)# 复制 Kvv.repeat_interleave(self.groups,dim2)# 复制 V# 标准 Attention——现在每个 Q 有对应的 K、Vscoretorch.matmul(q.transpose(1,2),k.transpose(1,2).transpose(-2,-1))scorescore/(self.head_dim**0.5)attnF.softmax(score,dim-1)outtorch.matmul(attn,v.transpose(1,2))returnout关键在repeat_interleave——把 8 组 K、V 广播成 32 份。显存省了 8 倍但计算时多了这下复制。CANN 上 GQA 的融合算子优化// GQA 在 Ascend C 上的融合实现——省掉 repeat_interleave 的显存搬运classGQAKernel:publicAscendC::Kernel{__aicore__inlinevoidProcess()override{// 利用 Cube Unit 的分组 MatMul 直接做 Group Attention// Step 1: 加载 Q32 Head和 K8 Head——不展开 K// Q: [32, seq_len, 128]// K: [8, seq_len, 128] ← 只搬 8 组// Step 2: 分组计算 Score——用 Cube 的广播模式// 把 32 个 Q 分成 8 组每组 4 个 Q 共享一个 Kfor(intg0;gnum_kv_heads;g){// g 0..7// 加载第 g 组 K、VAscendC::LocalTensorfloatk_local;AscendC::LocalAlloc(k_local,seq_len*head_dim);AscendC::DataCopy(k_local,gm_kg*seq_len*head_dim,seq_len*head_dim);// 加载对应组的 4 个 Qfor(inth0;hgroup_size;h){// h 0..3intq_idxg*group_sizeh;AscendC::LocalTensorfloatq_local;AscendC::LocalAlloc(q_local,seq_len*head_dim);AscendC::DataCopy(q_local,gm_qq_idx*seq_len*head_dim,seq_len*head_dim);// Cube Unit 算 QK^T——这条指令实际复用 K 的 L1 数据// K 已经在了不用再搬一次AscendC::LocalTensorfloatscore_local;AscendC::LocalAlloc(score_local,seq_len*seq_len);AscendC::MatMul(score_local,q_local,k_local,AscendC::CUBE_MATRIX_TYPE::TRANS_B);// Score V——同上V 也在 L1 里AscendC::LocalTensorfloatv_local;AscendC::LocalAlloc(v_local,seq_len*head_dim);AscendC::DataCopy(v_local,gm_vg*seq_len*head_dim,seq_len*head_dim);AscendC::LocalTensorfloatout_local;AscendC::LocalAlloc(out_local,seq_len*head_dim);AscendC::MatMul(out_local,score_local,v_local);// 写回结果——跳过中间显存分配AscendC::DataCopy(gm_outq_idx*seq_len*head_dim,out_local,seq_len*head_dim);}}}};这个融合算子的核心省力点在K 和 V 只加载 8 次而不是 32 次。每组内的 4 个 Q 复用同一份 K、V 的 L1 数据——搬运量减少 75%。Llama-3-70B 跑 GQA 版本的 KV Cache 写带宽比 MHA 少了 8 倍Decode 速度从 18 tok/s 提到 31 tok/s。参考仓库GQA 等 Attention 算子Transformer 加速库
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2638748.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!