CANN Triton排序选择算子优化
Sort/Select 算子优化【免费下载链接】cannbot-skillsCANNBot 是面向 CANN 开发的用于提升开发效率的系列智能体本仓库为其提供可复用的 Skills 模块。项目地址: https://gitcode.com/cann/cannbot-skills适用于需要迭代选择元素的算子NMS、TopK、ArgSort 等核心约束Triton Ascend 不支持break/continue/return和 Pythonif分支必须用tl.where mask 实现条件逻辑。禁止语法替代方案说明if cond:tl.where(cond, a, b)所有条件必须用 SIMD 友好的方式表达break/continue用循环变量 mask 控制循环次数固定用 mask 跳过无效迭代return无法提前返回所有路径必须执行到函数末尾标量条件赋值x y if condx tl.where(cond, y, x)标量变量更新必须用tl.where1.2 迭代选择的标准模式对于需要每次从剩余元素中选一个最优的算法如NMS标准模式是# 模式selection-sort 风格的迭代选择 for step in range(max_select): # 1. 线性扫描找最优候选 best_idx -1 best_score threshold for i in range(n_elements): score tl.load(scores_ptr i) higher (score best_score) active best_idx tl.where(higher, i, best_idx) best_score tl.where(higher, score, best_score) # 2. 检查是否找到有效候选 found (best_idx ! -1) active # 3. 记录结果仅当 found 时 tl.store(output_ptr count, best_idx.to(tl.int32), maskfound) count tl.where(found, count 1, count) # 4. 标记已选通过修改内存状态 tl.store(scores_ptr best_idx, sentinel_value, maskfound) # 5. 根据选中元素更新其他元素状态算子特定逻辑 # ... 例如 NMS 中计算 IoU 并抑制重叠 box关键原则用内存值如将 score 设为哨兵值表示已选/已抑制状态而非标量 flag用tl.where做所有条件选择不用 Pythonif用mask参数控制tl.load/tl.store的执行2. 算子特定实现2.1 NMS (Non-Maximum Suppression)算法语义验证框架对比的是 PyTorch 参考实现如30_NMS.py其语义通常包含先过滤只保留满足门槛条件的元素如score scores_threshold再降序排序参考实现通常用torch.argsort(..., descendingTrue, stableTrue)确定顺序迭代选择按排序后的顺序遍历若当前元素未被抑制则选中依赖抑制选中后根据算子特定规则抑制其他元素如 NMS 的 IoU 阈值数量限制最多输出max_output_size个达到即停止输出格式输出张量前num_selected个有效其余为 0 或哨兵值关键降序关系来自参考实现的排序步骤。Triton kernel 中没有显式排序而是通过迭代选择最高分来隐式复现降序语义。参考实现triton.jit def select_kernel( values_ptr, # 用于比较的值 selected_indices_ptr, # 输出选中的原始索引 num_selected_ptr, # 输出实际选中数量 n_elements, max_output_size: tl.constexpr, threshold: tl.constexpr, ): pid tl.program_id(0) active (pid 0) selected_count 0 for step in range(max_output_size): # 1. 线性扫描找最优候选 best_idx -1 best_val threshold for i in range(n_elements): val tl.load(values_ptr i) better (val best_val) active best_idx tl.where(better, i, best_idx) best_val tl.where(better, val, best_val) # 2. 检查是否找到有效候选 found (best_idx ! -1) active # 3. 记录结果 tl.store(selected_indices_ptr selected_count, best_idx.to(tl.int32), maskfound) selected_count tl.where(found, selected_count 1, selected_count) # 4. 标记已选防止重复 tl.store(values_ptr best_idx, sentinel_value, maskfound) # 5. 算子特定逻辑根据选中元素更新其他元素状态 # - NMS读取选中元素的数据计算与其他元素的关系如 IoU # 将满足条件的其他元素标记为已选/已抑制 # - TopK无需此步骤 # - 其他算子根据业务规则更新其他元素的值或标记 tl.store(num_selected_ptr, selected_count, maskactive)关键点:grid(1,)单核执行顺序依赖算法天然难以并行best_idx -1初始值配合found (best_idx ! -1)判断是否找到有效元素maskfound保护所有依赖best_idx的 load/store避免 -1 越界写入顺序自然为降序与参考实现argsort(descendingTrue)语义一致算子特定扩展NMS在通用模式阶段5加入读取选中 box 坐标计算与其他 box 的 IoU将 IoU threshold 的 box 的 score 设为哨兵值抑制。关键点:scores_f32 scores.float().contiguous()保证连续内存访问输出前num_selected个为原始索引按 score 降序其余为 0TopK无抑制逻辑阶段5为空。将哨兵值设为-float(inf)。常见错误# 错误Python if 分支 if score best_score: best_idx i # 正确tl.where best_idx tl.where(score best_score, i, best_idx)# 错误标量 flag 累积 keep True for j in range(n): if iou threshold: keep False # 正确通过内存状态传递 tl.store(scores_ptr j, -1.0, masksuppress)# 错误先收集所有保留元素再截断破坏降序 # 正确每次迭代只选一个天然满足降序和数量限制【免费下载链接】cannbot-skillsCANNBot 是面向 CANN 开发的用于提升开发效率的系列智能体本仓库为其提供可复用的 Skills 模块。项目地址: https://gitcode.com/cann/cannbot-skills创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2630347.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!