CANN/ops-transformer Floyd注意力梯度算子
FusedFloydAttentionGrad【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer产品支持情况产品是否支持Ascend 950PR/Ascend 950DT×Atlas A3 训练系列产品/Atlas A3 推理系列产品√Atlas A2 训练系列产品/Atlas A2 推理系列产品√Atlas 200I/500 A2 推理产品×Atlas 推理系列产品×Atlas 训练系列产品×功能说明算子功能训练场景下计算Floyd注意力的反向输出FloydAttn相较于传统FA主要是计算qk/pv注意力时会额外将seq作为batch轴从而转换为batchMatmul。计算公式已知注意力的正向计算公式为$$ PSoftmax(Mask(scale*(QK_1^T QK_2^T), atten_mask)) \ Y(PV_1PV_2) $$则注意力的反向计算公式为$$ SSoftmax(S) $$$$ dV_1P^TdY $$$$ dV_2P^TdY $$$$ dQ\frac{((dS)*K_1)}{\sqrt{d}}\frac{((dS)*K_2)}{\sqrt{d}} $$$$ dK_1\frac{((dS)^T*Q)}{\sqrt{d}} $$$$ dK_2\frac{((dS)^T*Q)}{\sqrt{d}} $$参数说明参数名输入/输出/属性描述数据类型数据格式query输入公式中的输入Q。FLOAT16、BFLOAT16NDkey1输入公式中的输入K1。FLOAT16、BFLOAT16NDvalue1输入公式中的输入V1。FLOAT16、BFLOAT16NDkey2输入公式中的输入K2。FLOAT16、BFLOAT16NDvalue2输入公式中的输入V2。FLOAT16、BFLOAT16NDdy输入公式中的输入dY。FLOAT16、BFLOAT16NDattenMaskOptional可选输入公式中的atten_mask表示注意力掩码取值为1代表该位不参与计算不生效为0代表该位参与计算。BOOL、UINT8NDscaleValue可选属性公式中的scale表示缩放系数作为计算流中Muls的scalar值。默认值为1.0。DOUBLE-dqOut输出公式中的dQ表示query的梯度。FLOAT16、BFLOAT16NDdk1Out输出公式中的dK1表示key1的梯度。FLOAT16、BFLOAT16NDdv1Out输出公式中的dV1表示value1的梯度。FLOAT16、BFLOAT16NDdk2Out输出公式中的dK2表示key2的梯度。FLOAT16、BFLOAT16NDdv2Out输出公式中的dV2表示value2的梯度。FLOAT16、BFLOAT16ND约束说明该接口与PyTorch配合使用时需要保证CANN相关包与PyTorch相关包的版本匹配关于数据shape的约束其中B取值范围为1~2K。H取值范围为1~256。N取值范围为16~1M且N%160。M取值范围为128~1M且M%1280。K取值范围为128~1M且K%1280。D取值范围为32/64/128。query与key1的第0/2/4轴需相同。key1与value1 shape需相同。key2与value2 shape需相同。query与dy/attentionIn shape需相同。softmaxMax与softmaxSum shape需相同。D只支持32/64/128。调用说明调用方式调用样例说明aclnn调用test_aclnn_fused_floyd_attention_grad通过接口方式调用aclnnFusedFloydAttentionGrad算子。【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2599844.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!