CANN/pypto量化矩阵乘法
pypto.scaled_mm【免费下载链接】pyptoPyPTO发音: pai p-t-oParallel Tensor/Tile Operation编程范式。项目地址: https://gitcode.com/cann/pypto产品支持情况产品是否支持Ascend 950PR/Ascend 950DT√功能说明实现mat_a 、mat_b矩阵的mx量化矩阵乘运算计算公式为out (mat_a * scale_a) (mat_b * scale_b)mat_a 、mat_b 、scale_a 、scale_b为源操作数mat_a 为左矩阵mat_b为右矩阵scale_a为左矩阵量化参数scale_b为右矩阵量化参数out 为目的操作数存放矩阵乘结果的矩阵函数原型scaled_mm(mat_a, mat_b, out_dtype, scale_a, scale_b, *, a_trans False, b_trans False, scale_a_trans False, scale_b_trans False, c_matrix_nz False, extend_paramsNone) - Tensor参数说明参数名输入/输出说明mat_a输入表示输入左矩阵。不支持输入空Tensor。支持的数据类型为DT_FP8E5M2, DT_FP8E4M3且左右矩阵数据类型需保持一致。支持的矩阵维度2维。输入矩阵支持的Format为TILEOP_ND, TILEOP_NZDT_FP8E5M2输入不支持TILEOP_NZ格式。内轴外轴当输入矩阵mat_a非转置时对应数据排布为[M, K]此时外轴为M内轴为K当输入矩阵mat_a转置时对应数据排布为[K, M]此时外轴为K内轴为M。当Format为TILEOP_NDND格式时外轴范围为[1, 2^31 - 1]内轴范围为[1, 65535]。当Format为TILEOP_NZNZ格式时其Shape维度需满足内轴32字节对齐外轴16元素对齐。在满足Format约束的基础上其Shape维度需满足K轴64元素对齐。在使用pypto.view接口的场景应保证传入View的Shape维度也满足内轴32字节对齐外轴16元素对齐。mat_b输入表示输入右矩阵。不支持输入空Tensor。支持的数据类型为DT_FP8E5M2, DT_FP8E4M3且左右矩阵数据类型需保持一致。支持的矩阵维度2维。输入矩阵支持的Format为TILEOP_ND, TILEOP_NZDT_FP8E5M2输入不支持TILEOP_NZ格式。内轴外轴当输入矩阵mat_b非转置时对应数据排布为[K, N]此时外轴为K内轴为N当输入矩阵mat_b转置时对应数据排布为[N, K]此时外轴为N内轴为K。当Format为TILEOP_NDND格式时外轴范围为[1, 2^31 - 1]内轴范围为[1, 65535]。当Format为TILEOP_NZNZ格式时其Shape维度需满足内轴32字节对齐外轴16元素对齐。在满足Format约束的基础上其Shape维度需满足K轴64元素对齐。在使用pypto.view接口的场景应保证传入View的Shape维度也满足内轴32字节对齐外轴16元素对齐。out_dtype输出表示输出矩阵数据类型支持DT_FP32DT_FP16DT_BF16。scale_a输入表示输入左矩阵量化参数。不支持输入空Tensor。支持的数据类型为DT_FP8E8M0。支持的量化参数维度3维。输入量化参数shape为当输入量化参数非转置时对应输入shape为[M, K/64, 2]当输入量化参数转置时对应输入shape为[K/64, M, 2]。其中M和K值等于输入矩阵mat_a的M、K值。输入量化参数支持的Format为TILEOP_ND。scale_b输入表示输入右矩阵量化参数。不支持输入空Tensor。支持的数据类型为DT_FP8E8M0。支持的量化参数维度3维。输入量化参数shape为当输入量化参数非转置时对应输入shape为[K/64, N, 2]当输入量化参数转置时对应输入shape为[N, K/64, 2]。其中M和K值等于输入矩阵mat_a的M、K值。输入量化参数支持的Format为TILEOP_ND。a_trans输入参数a_trans表示输入左矩阵是否转置默认为False。b_trans输入参数b_trans表示输入右矩阵是否转置默认为False。scale_a_trans输入参数scale_a_trans表示输入左矩阵量化参数是否转置默认为False。scale_b_trans输入参数scale_b_trans表示输入右矩阵量化参数是否转置默认为False。c_matrix_nz输入参数c_matrix_nz表示输出矩阵的Format是否采用NZ格式默认为False当前仅支持设置False即输出矩阵仅支持ND格式。extend_params输入支持bias及fixpipe的反量化功能数据类型为字典格式。默认为None当前仅支持bias场景。详见表2表2extend_params参数说明参数名说明bias_tensor表示偏置矩阵。输入为Tensor类型。Bias矩阵数据类型可选DT_FP16、DT_BF16和DT_FP32。bias_tensor只支持ND格式。bias_tensor的第一维度应置1且N维度需要与mat_b矩阵的N维度相等。仅支持矩阵维度为2维场景。不支持叠加多核切K功能。返回值说明返回值为out 矩阵Tensor。约束说明调用scaled_mm接口前需要通过pypto.set_cube_tile_shapes设置M、N、K轴上的切分大小。调用scaled_mm接口的输入为调用pypto.reshape后的NZ格式时需要调用pypto.set_matrix_size接口设置pypto.reshape前的输入到matmul的原始Shape的m,k,n值。调用示例mat_a pypto.tensor([64, 128], pypto.DT_FP8E5M2, mat_a) mat_b pypto.tensor([128, 32], pypto.DT_FP8E5M2, mat_b) scale_a pypto.tensor([64, 2, 2], pypto.DT_FP8E8M0, scale_a) scale_b pypto.tensor([2, 32, 2], pypto.DT_FP8E8M0, scale_b) out1 pypto.scaled_mm(mat_a, mat_b, pypto.DT_BF16, scale_a, scale_b) mat_a pypto.tensor([128, 64], pypto.DT_FP8E5M2, mat_a) mat_b pypto.tensor([32, 128], pypto.DT_FP8E5M2, mat_b) scale_a pypto.tensor([2, 64, 2], pypto.DT_FP8E8M0, scale_a) scale_b pypto.tensor([32, 2, 2], pypto.DT_FP8E8M0, scale_b) bias pypto.tensor((1, 32), pypto.DT_FP16, tensor_bias) extend_params {bias_tensor: bias} out1 pypto.scaled_mm(mat_a, mat_b, pypto.DT_BF16, scale_a, scale_b, scale_a_transTrue, scale_b_transTrue, extend_paramsextend_params)【免费下载链接】pyptoPyPTO发音: pai p-t-oParallel Tensor/Tile Operation编程范式。项目地址: https://gitcode.com/cann/pypto创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2633614.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!