CANN/ops-nn: 原位加法RMS归一化算子
InplaceAddRmsNorm【免费下载链接】ops-nn本项目是CANN提供的神经网络类计算算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-nn产品支持情况产品是否支持Ascend 950PR/Ascend 950DT√Atlas A3 训练系列产品/Atlas A3 推理系列产品√Atlas A2 训练系列产品/Atlas A2 推理系列产品√Atlas 200I/500 A2 推理产品×Atlas 推理系列产品√Atlas 训练系列产品×Kirin X90 处理器系列产品√Kirin 9030 处理器系列产品√功能说明算子功能RmsNorm算子是大模型常用的归一化操作相比LayerNorm算子其去掉了减去均值的部分。AddRmsNorm算子将RmsNorm前的Add算子融合起来减少搬入搬出操作。InplaceAddRmsNorm是一种结合了原位加法和RMS归一化的操作。计算公式$$ x_ix1_{i}x2_{i} $$$$ \operatorname{RmsNorm}(x_i) g_i * (x_i * \operatorname{rstd}(\mathbf{x})), \quad \text { where } \operatorname{rstd}(\mathbf{x})\frac{1}{\sqrt{\frac{1}{n} \sum_{i1}^n x_i^2eps}} $$参数说明参数名输入/输出/属性描述数据类型数据格式x1输入用于Add计算的第一个输入对应公式中的x1。FLOAT32、FLOAT16、BFLOAT16NDx2输入用于Add计算的第二个输入对应公式中的x2。FLOAT32、FLOAT16、BFLOAT16NDgamma输入表示RmsNorm的缩放因子权重对应公式中的g。FLOAT32、FLOAT16、BFLOAT16NDepsilon可选属性添加到分母中的值以确保数值稳定用于防止除0错误对应公式中的eps。默认值为1e-6f。FLOAT32-x1输出表示最后的输出Device侧的aclTensor对应公式中的RmsNorm(x)。FLOAT32、FLOAT16、BFLOAT16NDrstd输出表示归一化后的标准差倒数对应公式中的rstd。FLOAT32NDx2输出表示Add计算的结果对应公式中的x。FLOAT32、FLOAT16、BFLOAT16NDAtlas 推理系列产品 所有输入参数和输出参数x1、x2的数据类型不支持BFLOAT16。在当前产品下的使用场景下输出参数rstd为无效参数输出的值不生效。Kirin X90/Kirin 9030处理器系列产品所有输入、输出的数据类型不支持BFLOAT16。约束说明无调用说明调用方式样例代码说明aclnn接口test_aclnn_inplace_add_rms_norm通过aclnnInplaceAddRmsNorm接口方式调用InplaceAddRmsNorm算子。图模式-通过算子IR构图方式调用InplaceAddRmsNorm算子。【免费下载链接】ops-nn本项目是CANN提供的神经网络类计算算子库实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-nn创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2602339.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!