CANN Gather算子API描述
Gather 算子 API 描述【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力涵盖算子生成、算子优化等领域支撑模型选型、训练效果评估统一量化评估标准识别Agent能力短板构建CANN领域评测平台推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench1. 算子简介从输入 Tensor 的指定维度按 index 提取元素。主要应用场景嵌入层Embedding的查表操作注意力机制中按索引提取 Key/Value稀疏操作中按索引收集特征算子特征难度等级L2IndexGather双输入x 和 index单输出y按索引进行元素提取输入支持 ND 格式支持任意维度2. 算子定义数学公式$$ y[i][m][n] x[index[i]][m][n] $$更一般地对于batch_dimsk前 k 个维度作为 batch 维度在第 k 个维度上按 index 进行 gather 操作。3. 接口规范算子原型cann_bench.gather(Tensor x, Tensor index, int batch_dims) - Tensor y输入参数说明参数类型默认值描述xTensor必选输入张量indexTensor必选索引张量batch_dimsINT640batch 维度数输出参数Shapedtype描述y由 index shape 和 x 的非 gather 维度决定与输入 x 相同输出张量gather 结果数据类型x dtypeindex dtype输出 dtypefloat16int32 / int64float16float32int32 / int64float32bfloat16int32 / int64bfloat16int8int32 / int64int8int32int32 / int64int32int64int32 / int64int64规则与约束输入支持任意维度的 ND 格式张量batch_dims指定 batch 维度数前batch_dims个维度作为 batch 维度x 和 index 在这些维度上的大小必须一致index 中的值必须为有效索引即在 [0, x.shape[batch_dims]) 范围内输出 dtype 与输入 x 的 dtype 一致index 张量在 gather 维度之外的维度上shape 必须与 x 对应维度一致4. 精度要求采用生态算子精度标准进行验证。误差指标平均相对误差MERE采样点中相对误差平均值$$ \text{MERE} \text{avg}(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)\text{1e-7}}) $$最大相对误差MARE采样点中相对误差最大值$$ \text{MARE} \max(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)\text{1e-7}}) $$通过标准数据类型FLOAT16BFLOAT16FLOAT32HiFLOAT32FLOAT8 E4M3FLOAT8 E5M2通过阈值(Threshold)2^-102^-72^-132^-112^-32^-2当平均相对误差 MERE Threshold最大相对误差 MARE 10 * Threshold 时判定为通过。5. 标准 Golden 代码import torch Gather算子Torch Golden参考实现 从输入Tensor的指定维度按index提取元素 公式: y[i][m][n] x[index[i]][m][n] def gather( x: torch.Tensor, index: torch.Tensor, batch_dims: int 0 ) - torch.Tensor: 从输入Tensor的指定维度按index提取元素 公式: y[i][m][n] x[index[i]][m][n] Args: x: 输入张量 index: 索引张量 batch_dims: batch维度数 Returns: 输出张量gather结果 y torch.gather(x, batch_dims, index.long()) return y6. 额外信息算子调用示例import torch import cann_bench x torch.randn(1024, 1024, dtypetorch.float32, devicenpu) index torch.randint(0, 1024, (512, 1024), dtypetorch.int32, devicenpu) y cann_bench.gather(x, index, batch_dims0) # 沿第 0 维 gather x torch.randn(128, 128, 64, dtypetorch.float16, devicenpu) index torch.randint(0, 128, (128, 64, 64), dtypetorch.int64, devicenpu) y cann_bench.gather(x, index, batch_dims1) # batch_dims1【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力涵盖算子生成、算子优化等领域支撑模型选型、训练效果评估统一量化评估标准识别Agent能力短板构建CANN领域评测平台推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2598858.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!