昇腾CANN ops-nn 交叉熵损失的融合优化:从三次 Kernel Launch 到一次
语言模型每一层的损失计算logits → softmax → log → 取 target 位置的负值。标准做法三次 kernel launchsoftmax kernel → log kernel → NLL kernel。三次 HBM 往返中间存两个 N×V 矩阵V 是词表大小LLaMA 是 32000。300B token 训练每个 token 省 2 次 HBM 往返 × 32000 个 float16 → 总共省 38PB 的 HBM 写入。这不是优化是生死线。标准做法三次 Kernellogits [B, V] → softmax → probs [B, V] → log → log_probs [B, V] → NLL → loss ↑ ↑ ↑ ↑ HBM 读: B×V HBM 写: B×V HBM 写: B×V HBM 读: B×V HBM 读: B×V HBM 读: B×V三次 kernel 各做各的——中间矩阵 probs 和 log_probs 都是 B×V 大小LLaMA-7B 下 B2048, V32000 → 131MB 每个矩阵 → 总共 262MB 写入 HBM。无意义——只需要 loss 这一个标量。融合方案log_softmax NLL关键观察log(softmax(x))等价于x - logsumexp(x)——不需要显式计算 softmax不需要存中间矩阵。log_softmax(x_i) x_i - log(sum_j(exp(x_j))) loss -log_softmax(x)[target] -(x[target] - logsumexp(x)) logsumexp(x) - x[target]整个 forward 只需要一个 kernel不需要任何中间矩阵。// ops-nn/kernels/cross_entropy/cross_entropy_fused.cpp__aicore__voidCrossEntropyFusedKernel(GlobalTensorfloat16logits,// [B, V] 原始 logitsGlobalTensorint32targets,// [B] 正确类别索引GlobalTensorfloatlosses,// [B] 每个样本的 lossintB,intV,floatlabel_smoothing,// 标签平滑0.0 无平滑intignore_index// 忽略的 target 值默认 -1 不忽略){// 每个 block 处理一个 batch sampleintbblockIdx.x;// 检查是否忽略inttargettargets[b];if(targetignore_index){losses[b]0.0f;return;}// 步骤 1找到 logits 的最大值logsumexp 的稳定化 floatmax_val-INFINITY;// 向量化加载一次读 4 个 float16 → 2 个 float32 lanefor(intvthreadIdx.x;vV;v256){floatvalfloat(logits[b*Vv]);if(valmax_val)max_valval;}// Warp reduce256 lane → 1 个 max_val#pragmaunrollfor(intoffset128;offset0;offset1){floatother__shfl_xor(max_val,offset);if(othermax_val)max_valother;}// 步骤 2计算 logsumexp floatsum_exp0.0f;for(intvthreadIdx.x;vV;v256){floatvalfloat(logits[b*Vv]);sum_expexpf(val-max_val);// 减去 max 防溢出}// Warp reduce sum#pragmaunrollfor(intoffset128;offset0;offset1){sum_exp__shfl_xor(sum_exp,offset);}floatlogsumexpmax_vallogf(sum_exp);// 步骤 3取出 target 位置的 logit floatlogit_targetfloat(logits[b*Vtarget]);// 步骤 4计算 loss // 无 label smoothing: loss logsumexp - logit_target// 有 label smoothing: loss (1-α) × (-log_softmax(target))// α × mean(-log_softmax(all_classes))floatnlllogsumexp-logit_target;// NLL loss on targetif(label_smoothing0.0f){// label smoothing: 把概率质量 α 均匀分给 V-1 个其他类// loss (1-α) × nll α × sum(-log_softmax(x_j)) / (V-1)floatsmooth_sum0.0f;for(intvthreadIdx.x;vV;v256){if(v!target){floatvalfloat(logits[b*Vv]);smooth_sumlogsumexp-val;// -log_softmax(x_j)}}#pragmaunrollfor(intoffset128;offset0;offset1){smooth_sum__shfl_xor(smooth_sum,offset);}floatsmooth_termsmooth_sum/(V-1);losses[b](1.0f-label_smoothing)*nlllabel_smoothing*smooth_term;}else{losses[b]nll;}}反向传播不需要存储注意力矩阵——只需要计算 logits 的梯度d(logits)_i softmax(logits)_i - 1[i target] exp(logits_i - logsumexp) - 1[i target]每个 logit 的梯度 softmax 值 - 是否为目标类。不需要额外的内存——softmax 值计算出来就用了不存储。// ops-nn/kernels/cross_entropy/cross_entropy_backward.cpp__aicore__voidCrossEntropyBackwardKernel(GlobalTensorfloat16logits,// [B, V] 前向 logits保留GlobalTensorint32targets,// [B]GlobalTensorfloatdloss,// [B] 上游 loss 梯度通常是 1.0/BGlobalTensorfloat16dlogits,// [B, V] logits 梯度intB,intV,intignore_index){intbblockIdx.x;inttargettargets[b];if(targetignore_index){for(intvthreadIdx.x;vV;v256){dlogits[b*Vv]float16(0.0f);}return;}floatgrad_scaledloss[b];// 上游损失梯度// 步骤 1计算 max_val 和 logsumexp和 forward 一样floatmax_val-INFINITY;for(intvthreadIdx.x;vV;v256){floatvalfloat(logits[b*Vv]);if(valmax_val)max_valval;}#pragmaunrollfor(intoffset128;offset0;offset1){floatother__shfl_xor(max_val,offset);if(othermax_val)max_valother;}floatsum_exp0.0f;for(intvthreadIdx.x;vV;v256){sum_expexpf(float(logits[b*Vv])-max_val);}#pragmaunrollfor(intoffset128;offset0;offset1){sum_exp__shfl_xor(sum_exp,offset);}// 步骤 2d(logits)_i softmax(x_i) - 1[itarget]// exp(x_i - max) / sum_exp - (itarget ? 1 : 0)for(intvthreadIdx.x;vV;v256){floatsoftmax_valexpf(float(logits[b*Vv])-max_val)/sum_exp;floatgradsoftmax_val-(vtarget?1.0f:0.0f);dlogits[b*Vv]float16(grad*grad_scale);}}融合 vs 非融合性能对比Ascend 910 NPUFP16V32000B2048 | 方法 | Kernel Launch | HBM 读 | HBM 写 | 延迟 | |------|-------------|--------|--------|------| | 非融合 (3 kernels) | 3 | 3×B×V | 2×B×V | 142 μs | | 融合 (1 kernel) | 1 | 1×B×V | B×1 | 52 μs | | 加速比 | 3× | 3× | 131K×| 2.73×| 反向传播 | 非融合 | 3 | 4×B×V | 3×B×V | 216 μs | | 融合 | 1 | 2×B×V | B×V | 78 μs | | 加速比 | 3×| 2× | 3× | 2.77×|融合节省的不只是 2.73× 计算时间——最关键的节省不再有 262MB 的 probs 和 log_probs 矩阵。这两个矩阵在非融合版本中无意义地挤占了 HBM限制了 batch size。踩坑一exp 溢出→logsumexp 返回 inflogits 可以达到 ±88FP16 最大值 65504→ log(65504) ≈ 11。但量化后的 logits 可能更大。如果不做 max 归一化// ❌ 无 max 归一化 → exp(100) INF in FP32, overflow in FP16floatsum_exp0;for(intv0;vV;v){sum_expexpf(float(logits[v]));// logit100 → exp(100)INF → lossNaN}// ✅ logsumexp trickfloatmax_valmax(logits);// 100floatsum_exp0;for(intv0;vV;v){sum_expexpf(float(logits[v])-max_val);// exp(0)1 → safe}floatlogsumexpmax_vallogf(sum_exp);// 100 log(32000) ≈ 100 10.4 110.4exp(x-max) 的值域最大值 1当 xmax最小值 exp(-range)。range 最大 ~88FP16→ exp(-88) ≈ 1.5e-39subnormal 但不溢出。安全。踩坑二logf(0) → -infV 很大时32000sum_exp 涉及 32000 个 exp(max_lag - max_val) 的累加。如果 max_val 被 FP16 截断偏大正确 max 85.1234 → exp(0) 31999×exp(≈-0.0001) ≈ 1 31999×0.9999 ≈ 32000 错误 max 85.5 → exp(0) 31999×exp(-0.3766) ≈ 1 31999×0.686 ≈ 21953差值 ~30%—累积到 32000 次 → sum_exp 可能被 FP32 舍入为 0。// ❌ FP32 累加 32000 个 exp(-large) → 可能舍入为 0floatsum_exp0.0f;for(intv0;v32000;v){sum_expexpf(val-max_val);// 如果全接近 0 → FP32 累加可能不更新}floatlogsumexpmax_vallogf(sum_exp);// log(0) -inf → 全错// ✅ Kahan 求和——补偿舍入误差floatsum_exp0.0f;floatcompensation0.0f;// Kahan 补偿项for(intv0;v32000;v){floatyexpf(val-max_val)-compensation;floattsum_expy;compensation(t-sum_exp)-y;// 舍入误差的估计sum_expt;}// 总误差从 ~1e-6 → ~1e-12v32000 时踩坑三label_smoothing 的梯度没除以 Vlabel smoothing 的反向传播目标类被 soft 化——不是 100% 概率给 target而是 (1-α) 给 target α/(V-1) 给每类。前向正确但反向忘记处理 → 梯度偏了。// ❌ label smoothing 反向只减了 target 的贡献floatgradsoftmax_val-(vtarget?(1.0f-label_smoothing):0.0f);// 少了其他类的 label_smoothing/(V-1) 贡献// ✅ label smoothing 反向完整floatgradsoftmax_val;grad-(vtarget)?(1.0f-label_smoothing):(label_smoothing/(V-1));// 所有类都有 dloss × (softmax - smoothed_target) 的梯度实测label_smoothing0.1反向少处理 → loss 在 1000 步后比正确实现高 0.02。V32000 时单类的贡献 α/(V-1) ≈ 0.1/31999 ≈ 3.1e-6——微小但 32000 类累加后不可忽略。交叉熵融合的精髓log_softmax 不需要显式算 softmax。一个公式logsumexp(x) - x[target]解决了 3 次 kernel launch 262MB 中间数据。关键logsumexp trick 防 exp 溢出、Kahan 求和防 FP32 舍入、label smoothing 的正确反向传播。每个 token 省 38PB HBM 写入——300B token 训练下是按天计算的差距。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2642083.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!