CANN ops-transformer:MC2 通信融合算子怎么加速 MoE 的 All-to-All
MoE 的 Expert Parallel 需要全互连通信——每个 token 发给它路由到的专家所在的卡再收回来。这个 All-to-All 通信在 8 卡 MoE 上能占 30% 的推理时间。MC2Merge-Communicate-Split把通信和计算融合在一起在等数据的时候不闲着。All-to-All 通信的瓶颈先说清楚 All-to-All 是什么。在 Expert ParallelEP模式下每个专家分布在不同的 GPU/NPU 上。LLaMA-MoE 的 8 个专家分布在 8 张卡上每张卡有一个专家。每个 token 要经过两步通信第一步发送 token 到它被路由到的专家所在的卡。比如 token A 被路由到专家 3专家 3 在卡 3 上所以 token A 要从卡 0 发送到卡 3。第二步专家处理完 token结果要发回原卡。8 张卡全互连每张卡同时发送和接收总通信量是 8 × tensor_size × 平均跳数跳数取决于物理拓扑。All-to-All 的瓶颈是延迟而不是带宽。小消息多每个 token 只有几百个浮点数每条消息都需要握手、同步、网络排队。这些开销累积起来通信时间占 30%。MC2 的核心思路MC2 的核心观察是等数据的时候不闲着。标准流程是先发完所有数据等待等所有数据到齐了再开始计算处理。MC2 的流程是发第一个专家的数据同时准备第二个专家的输入等第一个专家的结果时同时发第三个专家的数据。通信和计算完全流水线。MC2 拆成三个阶段Merge 阶段把要发送给同一个专家的数据打包在一起。不同 token 路由到不同的专家Merge 阶段把它们按目标专家分组减少通信次数。Communicate 阶段用 HCCL 的 All-to-All 发送数据。MC2 把 All-to-All 拆成多个小批次和计算流水并行。Split 阶段收到的数据按来源卡分组把结果分发给各自的原卡。Ascend C 实现// MC2 融合算子的 Ascend C 核心逻辑// MC2 的关键设计让通信和计算 overlap// 每张卡维护一个状态机管理 8 个专家的处理状态// 专家处理状态枚举enumExpertState{WAITING_INPUT,// 等待输入数据PROCESSING,// 正在处理计算中SENDING_OUTPUT,// 正在发送输出DONE// 处理完成};// MC2 的主循环轮询 8 个专家的状态调度计算和通信__aicore__voidMC2Kernel(GM_ADDR local_expert_output,...){ExpertState states[NUM_EXPERTS]{WAITING_INPUT};bool all_donefalse;while(!all_done){// 遍历每个专家的状态for(inte0;eNUM_EXPERTS;e){switch(states[e]){caseWAITING_INPUT:// 检查是否有数据到达通过 HBM 标志位判断if(CheckDataArrived(e)){// 数据到达开始处理LoadExpertInput(e);// 从 HBM 读到 UBstates[e]PROCESSING;}break;casePROCESSING:// 专家计算可能需要多个迭代if(IsComputeDone(e)){// 计算完成准备发送states[e]SENDING_OUTPUT;}break;caseSENDING_OUTPUT:// 发送输出异步的数据写入 HBM 缓冲区由 driver 处理网络传输WriteOutputToHBM(e);NotifyPeerComplete(e);// 发信号给对端states[e]DONE;break;caseDONE:// 这个专家的处理已完成break;}}// 检查是否所有专家都完成了all_doneCheckAllDone(states);// 给 HCCL 通信线程一个机会往前推进YieldToHCCL();// 让 driver 处理已到达的网络数据}}MC2 的关键实现细节是 HCCL 的异步执行。All-to-All 的发送操作是异步的——调用 HCCL 的 All-to-All 接口只是发起传输实际数据还在搬运。MC2 在发起 All-to-All 后不阻塞等结果而是切换到其他专家的计算。等 All-to-All 完成通过 event 通知再切换回来处理结果。这个实现依赖昇腾 runtime 的异步通信 API 和 Event 同步机制。ops-transformer 仓的moe_mc2_fusion.cpp里有完整的实现。MC2 vs 标准 All-to-All标准 All-to-All 的时间线T0 发起通信 → T1 等待所有数据到达 → T2 全部数据到达开始计算 → T3 计算完成。总延迟 (T1-T0) (T3-T2)。MC2 的时间线T0 发起第 1 批通信同时开始准备第 2 批数据 → T1 第 1 批数据到达开始处理同时发起第 2 批通信 → T2 第 1 批处理完成开始发送同时处理第 2 批数据 → … → T3 全部完成。总延迟 T1-T0 T3-T2但 T3-T2 因为 overlap 大幅缩短。MC2 的收益取决于通信和计算的比率。如果专家计算很快专家参数量小通信时间占比高MC2 收益就大。如果专家计算很慢专家参数量大比如 Mixtral 的 experts8×7B通信时间占比低MC2 收益就小。实测数据Mixtral-8×7BEP8MC2 让 All-to-All 阶段的延迟从 25ms 降到 18ms节省约 28%。MC2 的收益随 EP 规模增长——EP2 时All-to-All 的总数据量和通信距离都小MC2 收益只有 5-8%。EP8 时收益达到 25-30%。https://atomgit.com/cann/ops-transformer
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2638859.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!