# CANN Quantized Attention Gradient Operator
## aclnnQuantFlashAttentionScoreGrad

Part of ops-transformer, the transformer operator library provided by CANN to accelerate large-model networks on NPU. Repository: https://gitcode.com/cann/ops-transformer

### Product Support

| Product | Supported |
| --- | --- |
| Ascend 950PR / Ascend 950DT | √ |
| Atlas A3 Training Series / Atlas A3 Inference Series | × |
| Atlas A2 Training Series / Atlas A2 Inference Series | × |
| Atlas 200I/500 A2 Inference Products | × |
| Atlas Inference Series Products | × |
| Atlas Training Series Products | × |

### Function

This interface implements the fused, quantized backward computation of the Transformer attention score. The forward computation is:

$$ Y = Softmax\left(\frac{\hat{Q}\hat{K}^T \cdot (dS_q \cdot dS_k)}{\sqrt{d}}\right)\hat{V} \cdot dS_v $$

For convenience, define the variables $S$ and $P$ as:

$$ S = \frac{\hat{Q}\hat{K}^T \cdot (dS_q \cdot dS_k)}{\sqrt{d}} $$

$$ P = Softmax(S) $$

$$ Y = P\hat{V} \cdot dS_v $$

The backward computation of attention is then:

$$ \hat{dS} = dS \cdot dsScale $$

$$ \hat{P} = P \cdot pScale $$

$$ dV = \hat{P}^T \hat{dY} \cdot (dS_{dy} \cdot dS_p) $$

$$ dQ = \frac{\hat{dS}\,\hat{K}}{\sqrt{d}} \cdot (dS_{ds} \cdot dS_k) $$

$$ dK = \frac{\hat{dS}^T \hat{Q}}{\sqrt{d}} \cdot (dS_{ds} \cdot dS_q) $$

### Function Prototype

The operator uses a two-stage interface: first call `aclnnQuantFlashAttentionScoreGradGetWorkspaceSize` to obtain the required workspace size and an executor that holds the computation graph, then call `aclnnQuantFlashAttentionScoreGrad` to perform the computation.

```cpp
aclnnStatus aclnnQuantFlashAttentionScoreGradGetWorkspaceSize(
    const aclTensor *query, const aclTensor *keyIn, const aclTensor *value,
    const aclTensor *dy, const aclTensor *attenMaskOptional,
    const aclTensor *softmaxMax, const aclTensor *softmaxSum,
    const aclTensor *attentionIn, const aclTensor *dScaleQ,
    const aclTensor *dScaleK, const aclTensor *dScaleV,
    const aclTensor *dScaleDy, const aclTensor *dsScale,
    const aclTensor *pScale, double scaleValue, int64_t preTokens,
    int64_t nextTokens, int64_t headNum, char *inputLayout,
    int64_t sparseMode, int64_t outDtype, aclTensor *dqOut,
    aclTensor *dkOut, aclTensor *dvOut, uint64_t *workspaceSize,
    aclOpExecutor **executor)

aclnnStatus aclnnQuantFlashAttentionScoreGrad(
    void *workspace, uint64_t workspaceSize, aclOpExecutor *executor,
    aclrtStream stream)
```
### aclnnQuantFlashAttentionScoreGradGetWorkspaceSize Parameters

| Name | In/Out | Description | Usage Notes | Data Type | Format | Dims | Non-contiguous Tensor |
| --- | --- | --- | --- | --- | --- | --- | --- |
| query | Input | Q in the formulas. | Data type must match keyIn/value. | HIFLOAT8 | ND | 4 | √ |
| keyIn | Input | K in the formulas. | Data type must match query/value. | HIFLOAT8 | ND | 4 | √ |
| value | Input | V in the formulas. | Data type must match query/keyIn. | HIFLOAT8 | ND | 4 | √ |
| dy | Input | dY in the formulas. | - | HIFLOAT8 | ND | 4 | √ |
| attenMaskOptional | Optional input | Currently unused. | - | BOOL, UINT8 | ND | 4 | √ |
| softmaxMax | Input | Intermediate output of the attention forward pass. | Shape: [B, N, Sq, 1]. | FLOAT32 | ND | 4 | √ |
| softmaxSum | Input | Intermediate output of the attention forward pass. | Shape: [B, N, Sq, 1]. | FLOAT32 | ND | 4 | √ |
| attentionIn | Input | Final output of the attention forward pass. | Shape matches query. | BFLOAT16 | ND | 4 | √ |
| dScaleQ | Input | Dequantization scale of the query input. | Supports [B, N1, Ceil(Sq/blocksize), 1]; blocksize currently supports 512. | FLOAT32 | ND | 4 | √ |
| dScaleK | Input | Dequantization scale of the key input. | Supports [B, N2, Ceil(Skv/blocksize), 1]; blocksize currently supports 512. | FLOAT32 | ND | 4 | √ |
| dScaleV | Input | Dequantization scale of the value input. | Supports [B, N2, Ceil(Skv/blocksize), 1]; blocksize currently supports 512. | FLOAT32 | ND | 4 | √ |
| dScaleDy | Input | Dequantization scale of the dy input. | Supports [B, N1, Ceil(Sq/blocksize), 1]; blocksize currently supports 512. | FLOAT32 | ND | 4 | √ |
| dsScale | Input | Quantization scale of dS. | Supports [1]. | FLOAT32 | ND | 1 | √ |
| pScale | Input | Quantization scale of P. | Supports [1]. | FLOAT32 | ND | 1 | √ |
| scaleValue | Input | The scale factor in the formulas; defaults to 1. | - | DOUBLE | - | - | - |
| preTokens | Optional input | Currently unused. | - | INT64 | - | - | - |
| nextTokens | Optional input | Currently unused. | - | INT64 | - | - | - |
| headNum | Input | Number of heads per device; corresponds to the N axis of query. | - | INT64 | - | - | - |
| inputLayout | Input | Data layout of query/key/value. | Supports BSND. | String | - | - | - |
| sparseMode | Optional input | Currently unused. | - | INT64 | - | - | - |
| outDtype | Input | 0: outputs (dqOut, etc.) are FLOAT16; 1: outputs are BFLOAT16. | - | INT64 | - | - | - |
| dqOut | Output | dQ in the formulas: the gradient of query. | - | BFLOAT16 | ND | 4 | √ |
| dkOut | Output | dK in the formulas: the gradient of keyIn. | - | BFLOAT16 | ND | 4 | √ |
| dvOut | Output | dV in the formulas: the gradient of value. | - | BFLOAT16 | ND | 4 | √ |
| workspaceSize | Output | Size of the device-side workspace to allocate. | - | - | - | - | - |
| executor | Output | Op executor holding the computation graph. | - | - | - | - | - |

#### Return Value

Returns an aclnnStatus status code (see the aclnn return codes). The first-stage interface validates its arguments and reports the following errors:

| Return Code | Error Code | Description |
| --- | --- | --- |
| ACLNN_ERR_PARAM_NULLPTR | 161001 | A required input, output, or attribute is a null pointer. |
| ACLNN_ERR_PARAM_INVALID | 161002 | The data type or shape of query, keyIn, value, dy, softmaxMax, softmaxSum, attentionIn, dScaleQ, dScaleK, dScaleV, dScaleDy, dqOut, dkOut, or dvOut is outside the supported range. |

### aclnnQuantFlashAttentionScoreGrad Parameters

| Name | In/Out | Description |
| --- | --- | --- |
| workspace | Input | Address of the workspace memory allocated on the device. |
| workspaceSize | Input | Size of the device-side workspace, obtained from the first-stage interface aclnnQuantFlashAttentionScoreGradGetWorkspaceSize. |
| executor | Input | Op executor holding the computation graph. |
| stream | Input | Stream on which the task executes. |

#### Return Value

Returns an aclnnStatus status code (see the aclnn return codes).

### Constraints

- Deterministic computation: aclnnQuantFlashAttentionScoreGrad is deterministic by default.
- Constraints on the query, key, value, and dy inputs:
  - B (batch size) must be equal across them.
  - inputLayout must be identical.
  - D supports 128 only.
  - The N of query/dy must equal the N of key/value.
- Shape constraints: the following shapes are currently supported.

| Layout | Query Shape | Key Shape |
| --- | --- | --- |
| BSND | [1, 54000, 5, 128] | [1, 54000, 5, 128] |
| BSND | [1, 9360, 40, 128] | [1, 9360, 40, 128] |
| BSND | [1, 54000, 10, 128] | [1, 54000, 10, 128] |
| BSND | [1, 9360, 80, 128] | [1, 9360, 80, 128] |
| BSND | [1, 57600, 5, 128] | [1, 57600, 5, 128] |
| BSND | [1, 7200, 40, 128] | [1, 512, 40, 128] |

- In some scenarios an excessively large workload may cause the operator to time out (an aicore error whose errorStr is "timeout or trap error"); in that case, split the computation along an axis. Note that the workload grows with B, S, N, and D.
- softmaxMax and softmaxSum must have the fixed shape [B, N, S, 1].
- headNum must equal the N of the supplied query.

### Example

The sample code below is for reference only; for the concrete build and run procedure, see the compile-and-run samples (编译与运行样例). Helper macros and device initialization:

```cpp
#include <iostream>
#include <vector>
#include <cstdio>
#include <cstdint>
#include <cmath>
#include "acl/acl.h"
#include "aclnnop/aclnn_flash_attention_score_grad.h"

#define CHECK_RET(cond, return_expr) \
    do {                             \
        if (!(cond)) {               \
            return_expr;             \
        }                            \
    } while (0)

#define LOG_PRINT(message, ...)         \
    do {                                \
        printf(message, ##__VA_ARGS__); \
    } while (0)

int64_t GetShapeSize(const std::vector<int64_t> shape)
{
    int64_t shapeSize = 1;
    for (auto i : shape) {
        shapeSize *= i;
    }
    return shapeSize;
}

void PrintOutResult(std::vector<int64_t> shape, void **deviceAddr)
{
    auto size = GetShapeSize(shape);
    std::vector<float> resultData(size, 0);
    auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), *deviceAddr,
                           size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return);
    for (int64_t i = 0; i < size; i++) {
        LOG_PRINT("mean result[%ld] is: %f\n", i, resultData[i]);
    }
}

int Init(int32_t deviceId, aclrtStream *stream)
{
    // Fixed boilerplate: initialize ACL resources
    auto ret = aclInit(nullptr);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
    ret = aclrtSetDevice(deviceId);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
    ret = aclrtCreateStream(stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
    return 0;
}
```
The tensor-creation helper and the main routine:

```cpp
template <typename T>
int CreateAclTensor(const std::vector<T> &hostData, const std::vector<int64_t> &shape, void **deviceAddr,
                    aclDataType dataType, aclTensor **tensor)
{
    auto size = GetShapeSize(shape) * sizeof(T);
    // Allocate device-side memory with aclrtMalloc
    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
    // Copy host-side data to the device-side memory with aclrtMemcpy
    ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
    // Compute the strides of a contiguous tensor
    std::vector<int64_t> strides(shape.size(), 1);
    for (int64_t i = shape.size() - 2; i >= 0; i--) {
        strides[i] = shape[i + 1] * strides[i + 1];
    }
    // Create the aclTensor with aclCreateTensor
    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,
                              shape.data(), shape.size(), *deviceAddr);
    return 0;
}

int main()
{
    // 1. Fixed boilerplate: device/stream initialization (see the acl API reference).
    //    Set deviceId according to the actual device in use.
    int32_t deviceId = 0;
    aclrtStream stream;
    auto ret = Init(deviceId, &stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);

    // 2. Construct the inputs and outputs according to the API definition
    int64_t B = 1;
    int64_t N1 = 40;
    int64_t N2 = 40;
    int64_t S1 = 7200;
    int64_t S2 = 512;
    int64_t D = 128;
    int64_t H1 = N1 * D;
    int64_t H2 = N2 * D;
    int64_t blockNumQ = (S1 + 511) / 512;
    int64_t blockNumKV = (S2 + 511) / 512;
    int64_t q_size = B * N1 * S1 * D;
    int64_t kv_size = B * N2 * S2 * D;
    int64_t softmax_size = B * N1 * S1 * 1;
    int64_t scaleSizeQ = B * N1 * blockNumQ * 1;
    int64_t scaleSizeKV = B * N2 * blockNumKV * 1;
    std::vector<int64_t> qShape = {B, S1, N1, D};
    std::vector<int64_t> kShape = {B, S2, N2, D};
    std::vector<int64_t> vShape = {B, S2, N2, D};
    std::vector<int64_t> dxShape = {B, S1, N1, D};
    std::vector<int64_t> attenmaskShape = {S1, S2};
    std::vector<int64_t> softmaxMaxShape = {B, N1, S1, 1};
    std::vector<int64_t> softmaxSumShape = {B, N1, S1, 1};
    std::vector<int64_t> attentionInShape = {B, S1, N1, D};
    std::vector<int64_t> dScaleQShape = {B, N1, blockNumQ, 1};
    std::vector<int64_t> dScaleKShape = {B, N2, blockNumKV, 1};
    std::vector<int64_t> dScaleVShape = {B, N2, blockNumKV, 1};
    std::vector<int64_t> dScaleDyShape = {B, N1, blockNumQ, 1};
    std::vector<int64_t> dsScaleShape = {1};
    std::vector<int64_t> pScaleShape = {1};
    std::vector<int64_t> dqShape = {B, S1, N1, D};
    std::vector<int64_t> dkShape = {B, S2, N2, D};
    std::vector<int64_t> dvShape = {B, S2, N2, D};
    std::vector<int64_t> printShape = {B, S2, 1, D};
    void *qDeviceAddr = nullptr;
    void *kDeviceAddr = nullptr;
    void *vDeviceAddr = nullptr;
    void *dxDeviceAddr = nullptr;
    void *softmaxMaxDeviceAddr = nullptr;
    void *softmaxSumDeviceAddr = nullptr;
    void *attentionInDeviceAddr = nullptr;
    void *dScaleQDeviceAddr = nullptr;
    void *dScaleKDeviceAddr = nullptr;
    void *dScaleVDeviceAddr = nullptr;
    void *dScaleDyDeviceAddr = nullptr;
    void *dsScaleDeviceAddr = nullptr;
    void *pScaleDeviceAddr = nullptr;
    void *dqDeviceAddr = nullptr;
    void *dkDeviceAddr = nullptr;
    void *dvDeviceAddr = nullptr;
    aclTensor *q = nullptr;
    aclTensor *k = nullptr;
    aclTensor *v = nullptr;
    aclTensor *dx = nullptr;
    aclTensor *attenmask = nullptr;
    aclTensor *softmaxMax = nullptr;
    aclTensor *softmaxSum = nullptr;
    aclTensor *attentionIn = nullptr;
    aclTensor *dScaleQ = nullptr;
    aclTensor *dScaleK = nullptr;
    aclTensor *dScaleV = nullptr;
    aclTensor *dScaleDy = nullptr;
    aclTensor *dsScale = nullptr;
    aclTensor *pScale = nullptr;
    aclTensor *dq = nullptr;
    aclTensor *dk = nullptr;
    aclTensor *dv = nullptr;
    std::vector<uint8_t> qHostData(q_size, 1);
    std::vector<uint8_t> kHostData(kv_size, 1);
    std::vector<uint8_t> vHostData(kv_size, 1);
    std::vector<uint8_t> dxHostData(q_size, 1);
    std::vector<float> softmaxMaxHostData(softmax_size, 3.0);
    std::vector<float> softmaxSumHostData(softmax_size, 3.0);
    std::vector<float> attentionInHostData(q_size, 1.0);
    std::vector<float> dScaleQHostData(scaleSizeQ, 1.0);
    std::vector<float> dScaleKHostData(scaleSizeKV, 1.0);
    std::vector<float> dScaleVHostData(scaleSizeKV, 1.0);
    std::vector<float> dScaleDyHostData(scaleSizeQ, 1.0);
    std::vector<float> dsScaleHostData(1, 1.0);
    std::vector<float> pScaleHostData(1, 1.0);
    std::vector<float> dqHostData(q_size, 0);
    std::vector<float> dkHostData(kv_size, 0);
    std::vector<float> dvHostData(kv_size, 0);
    ret = CreateAclTensor(qHostData, qShape, &qDeviceAddr, aclDataType::ACL_HIFLOAT8, &q);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(kHostData, kShape, &kDeviceAddr, aclDataType::ACL_HIFLOAT8, &k);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(vHostData, vShape, &vDeviceAddr, aclDataType::ACL_HIFLOAT8, &v);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dxHostData, dxShape, &dxDeviceAddr, aclDataType::ACL_HIFLOAT8, &dx);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(softmaxMaxHostData, softmaxMaxShape, &softmaxMaxDeviceAddr, aclDataType::ACL_FLOAT, &softmaxMax);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(softmaxSumHostData, softmaxSumShape, &softmaxSumDeviceAddr, aclDataType::ACL_FLOAT, &softmaxSum);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(attentionInHostData, attentionInShape, &attentionInDeviceAddr, aclDataType::ACL_BF16, &attentionIn);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dScaleQHostData, dScaleQShape, &dScaleQDeviceAddr, aclDataType::ACL_FLOAT, &dScaleQ);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dScaleKHostData, dScaleKShape, &dScaleKDeviceAddr, aclDataType::ACL_FLOAT, &dScaleK);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dScaleVHostData, dScaleVShape, &dScaleVDeviceAddr, aclDataType::ACL_FLOAT, &dScaleV);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dScaleDyHostData, dScaleDyShape, &dScaleDyDeviceAddr, aclDataType::ACL_FLOAT, &dScaleDy);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dsScaleHostData, dsScaleShape, &dsScaleDeviceAddr, aclDataType::ACL_FLOAT, &dsScale);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(pScaleHostData, pScaleShape, &pScaleDeviceAddr, aclDataType::ACL_FLOAT, &pScale);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dqHostData, dqShape, &dqDeviceAddr, aclDataType::ACL_BF16, &dq);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dkHostData, dkShape, &dkDeviceAddr, aclDataType::ACL_BF16, &dk);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    ret = CreateAclTensor(dvHostData, dvShape, &dvDeviceAddr, aclDataType::ACL_BF16, &dv);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    double scaleValue = 1.0 / sqrt(128);
    int64_t preTokens = INT32_MAX;
    int64_t nextTokens = INT32_MAX;
    int64_t headNum = N1;
    int64_t sparseMode = 0;
    char layOut[6] = {'B', 'S', 'N', 'D', 0};
    int64_t outDtype = 1; // BFLOAT16 outputs

    // 3. Call the CANN operator library API (replace with the concrete API name as needed)
    uint64_t workspaceSize = 0;
    aclOpExecutor *executor;
    // First-stage interface: query the workspace size
    ret = aclnnQuantFlashAttentionScoreGradGetWorkspaceSize(
        q, k, v, dx, attenmask, softmaxMax, softmaxSum, attentionIn, dScaleQ, dScaleK, dScaleV, dScaleDy,
        dsScale, pScale, scaleValue, preTokens, nextTokens, headNum, layOut, sparseMode, outDtype,
        dq, dk, dv, &workspaceSize, &executor);
    CHECK_RET(ret == ACL_SUCCESS,
              LOG_PRINT("aclnnQuantFlashAttentionScoreGradGetWorkspaceSize failed. ERROR: %d\n", ret); return ret);
    // Allocate device memory for the workspace size reported by the first stage
    void *workspaceAddr = nullptr;
    if (workspaceSize > 0) {
        ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret);
    }
    // Second-stage interface: launch the computation
    ret = aclnnQuantFlashAttentionScoreGrad(workspaceAddr, workspaceSize, executor, stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnQuantFlashAttentionScoreGrad failed. ERROR: %d\n", ret); return ret);

    // 4. Fixed boilerplate: synchronize and wait for the task to finish
    ret = aclrtSynchronizeStream(stream);
    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);

    // 5. Copy the results from device memory back to the host (adapt to the concrete API)
    PrintOutResult(printShape, &dqDeviceAddr);
    PrintOutResult(printShape, &dkDeviceAddr);
    PrintOutResult(printShape, &dvDeviceAddr);

    // 6. Destroy the aclTensor objects (adapt to the concrete API)
    aclDestroyTensor(q);
    aclDestroyTensor(k);
    aclDestroyTensor(v);
    aclDestroyTensor(dx);
    aclDestroyTensor(attenmask);
    aclDestroyTensor(softmaxMax);
    aclDestroyTensor(softmaxSum);
    aclDestroyTensor(attentionIn);
    aclDestroyTensor(dScaleQ);
    aclDestroyTensor(dScaleK);
    aclDestroyTensor(dScaleV);
    aclDestroyTensor(dScaleDy);
    aclDestroyTensor(dsScale);
    aclDestroyTensor(pScale);
    aclDestroyTensor(dq);
    aclDestroyTensor(dk);
    aclDestroyTensor(dv);

    // 7. Release device resources
    aclrtFree(qDeviceAddr);
    aclrtFree(kDeviceAddr);
    aclrtFree(vDeviceAddr);
    aclrtFree(dxDeviceAddr);
    aclrtFree(softmaxMaxDeviceAddr);
    aclrtFree(softmaxSumDeviceAddr);
    aclrtFree(attentionInDeviceAddr);
    aclrtFree(dScaleQDeviceAddr);
    aclrtFree(dScaleKDeviceAddr);
    aclrtFree(dScaleVDeviceAddr);
    aclrtFree(dScaleDyDeviceAddr);
    aclrtFree(dsScaleDeviceAddr);
    aclrtFree(pScaleDeviceAddr);
    aclrtFree(dqDeviceAddr);
    aclrtFree(dkDeviceAddr);
    aclrtFree(dvDeviceAddr);
    if (workspaceSize > 0) {
        aclrtFree(workspaceAddr);
    }
    aclrtDestroyStream(stream);
    aclrtResetDevice(deviceId);
    aclFinalize();
    return 0;
}
```