FlashAttention计算过程梳理

news2025/7/21 13:20:18

FlashAttention 的速度优化原理是怎样的?
从 FlashAttention 到 PagedAttention, 如何进一步优化 Attention 性能
FlashAttention图解(如何加速Attention)
FlashAttention开源代码
Transformer Block运算量解析

在这里插入图片描述

  在self-attention模块中,主要包含全连接层(通过矩阵乘法实现)、softmax(计算注意力权重),以及根据注意力权重的加权求和(计算注意力的输出结果)。其中,全连接层和根据注意力权重的加权求和其实都是通过矩阵乘法实现的,所以分块计算可以通过矩阵的分块乘法来实现。由于softmax的分母部分需要计算全局元素的求和,分块之后只能计算局部的和,导致softmax的分块计算变得复杂。

  • 标准版softmax
    s o f t m a x ( x ) = e x i ∑ e x j softmax(x)=\frac{e^{x_i} }{\sum e^{x_j}} softmax(x)=exjexi

  • 稳定版softmax
    s o f t m a x ( x ) = e x i − m a x ( x ) ∑ e x j − m a x ( x ) softmax(x)=\frac{e^{x_i - max(x)} }{\sum e^{x_j - max(x)}} softmax(x)=exjmax(x)eximax(x)
      其中, m a x ( x ) max(x) max(x)表示 x x x 中的最大值。

  • 分块计算softmax

  1. 将数值序列 x x x 分成不同的块 x ( 1 ) , x ( 2 ) , . . . , x ( n ) x^{(1)},x^{(2)},...,x^{(n)} x(1),x(2),...,x(n)
  2. 使用稳定版softmax计算第一块 x ( 1 ) x^{(1)} x(1) 的结果,同时记录下第一块的最大值 m ( x ( 1 ) ) m(x^{(1)}) m(x(1)) 和第一块的局部求和结果 l ( x ( 1 ) ) = ∑ e x ( 1 ) − m ( x ( 1 ) ) l(x^{(1)}) = \sum {e^{x^{(1)} - m(x^{(1)})}} l(x(1))=ex(1)m(x(1))
  3. 设置变量 m m a x m_{max} mmax 记录迭代到此的全局最大值,设置变量 l a l l l_{all} lall 记录迭代到此的全局求和结果,后续随着迭代计算不同的分块 x ( i ) x^{(i)} x(i) 逐步更新 m m a x m_{max} mmax l a l l l_{all} lall。计算完第一块之后 m m a x = m ( x ( 1 ) ) m_{max} = m(x^{(1)}) mmax=m(x(1)) , l a l l = l ( x ( 1 ) ) l_{all} = l(x^{(1)}) lall=l(x(1))
  4. 使用稳定版softmax计算第二块 x ( 2 ) x^{(2)} x(2) 的结果,得到 m ( x ( 2 ) ) m(x^{(2)}) m(x(2)) l ( x ( 2 ) ) = ∑ e x ( 2 ) − m ( x ( 2 ) ) l(x^{(2)}) = \sum {e^{x^{(2)} - m(x^{(2)})}} l(x(2))=ex(2)m(x(2))
  5. 更新迭代到此时的全局最大值 m m a x n e w = m a x ( m m a x , m ( x ( 2 ) ) ) m_{max}^{new} = max(m_{max}, m(x^{(2)})) mmaxnew=max(mmax,m(x(2)))
  6. 更新迭代到此时的全局求和结果 l a l l n e w = e m m a x − m m a x n e w ∗ l a l l + e m ( x ( 2 ) ) − m m a x n e w ∗ l ( x ( 2 ) ) l_{all}^{new} = e^{m_{max} - m_{max}^{new}}*l_{all} + e^{m(x^{(2)}) - m_{max}^{new}}*l(x^{(2)}) lallnew=emmaxmmaxnewlall+em(x(2))mmaxnewl(x(2))

  关于第 6 步的公式是怎么得到的,我们把第 6 步的公式拆解为两部分,现在我们计算到了第二块数据 x ( 2 ) x^{(2)} x(2),所以我们此时的全局求和结果由两部分组成,第一部分是由 x ( 1 ) x^{(1)} x(1) 数据块产生的求和结果,第二部分是由 x ( 2 ) x^{(2)} x(2) 数据块产生的求和结果,但是 x ( 1 ) x^{(1)} x(1) x ( 2 ) x^{(2)} x(2) 计算的求和结果分别使用的是各自局部的最大值 m a x ( x ) max(x) max(x) 进行计算的,所以要将 x ( 1 ) x^{(1)} x(1) x ( 2 ) x^{(2)} x(2) 的局部求和结果更新为当前阶段的全局求和结果。

  以更新 x ( 2 ) x^{(2)} x(2) 的求和结果为例,在计算 x ( 2 ) x^{(2)} x(2) 的softmax的过程中,分子分母同时除以了 x ( 2 ) x^{(2)} x(2) 的局部最大值的 e m ( x ( 2 ) ) e^{m(x^{(2)})} em(x(2)),所以现在要对分母部分 x ( 2 ) x^{(2)} x(2) 的局部求和结果进行还原,先乘以局部最大值的 e m ( x ( 2 ) ) e^{m(x^{(2)})} em(x(2)),然后在除以全局的最大值的 e m m a x n e w e^{m_{max}^{new}} emmaxnew,公式表示如下:
l ( x ( 2 ) ) n e w = ∑ e x ( 2 ) − m ( x ( 2 ) ) ∗ e m ( x ( 2 ) ) e m m a x n e w = ∑ e x ( 2 ) − m m a x n e w = e m ( x ( 2 ) ) − m m a x n e w ∗ l ( x ( 2 ) ) l(x^{(2)})_{new} = \frac {\sum {e^{x^{(2)} - m(x^{(2)})}} * e^{m(x^{(2)})}}{e^{m_{max}^{new}}} = \sum {e^{x^{(2)} - m_{max}^{new}}} = e^{m(x^{(2)}) - m_{max}^{new}}*l(x^{(2)}) l(x(2))new=emmaxnewex(2)m(x(2))em(x(2))=ex(2)mmaxnew=em(x(2))mmaxnewl(x(2))
  同理,也可以使用迭代到此时的全局的最大值的 e m m a x n e w e^{m_{max}^{new}} emmaxnew ,更新数据块 $x^{(1)} 的局部求和结果为迭代到此时的全局求和结果 $ l ( x ( 1 ) ) n e w l(x^{(1)})_{new} l(x(1))new,表示如下:
l ( x ( 1 ) ) n e w = ∑ e x ( 1 ) − m ( x ( 1 ) ) ∗ e m ( x ( 1 ) ) e m m a x n e w = ∑ e x ( 1 ) − m m a x n e w = e m ( x ( 1 ) ) − m m a x n e w ∗ l ( x ( 1 ) ) l(x^{(1)})_{new} = \frac {\sum {e^{x^{(1)} - m(x^{(1)})}} * e^{m(x^{(1)})}}{e^{m_{max}^{new}}} = \sum {e^{x^{(1)} - m_{max}^{new}}} = e^{m(x^{(1)}) - m_{max}^{new}}*l(x^{(1)}) l(x(1))new=emmaxnewex(1)m(x(1))em(x(1))=ex(1)mmaxnew=em(x(1))mmaxnewl(x(1))

  所以,迭代到此时的全局求和结果就是 l a l l n e w = l ( x ( 1 ) ) n e w + l ( x ( 2 ) ) n e w l_{all}^{new} = l(x^{(1)})_{new} + l(x^{(2)})_{new} lallnew=l(x(1))new+l(x(2))new ,表示如下:
l a l l n e w = l ( x ( 1 ) ) n e w + l ( x ( 2 ) ) n e w = e m ( x ( 1 ) ) − m m a x n e w ∗ l ( x ( 1 ) ) + e m ( x ( 2 ) ) − m m a x n e w ∗ l ( x ( 2 ) ) l_{all}^{new} = l(x^{(1)})_{new} + l(x^{(2)})_{new} = e^{m(x^{(1)}) - m_{max}^{new}}*l(x^{(1)}) + e^{m(x^{(2)}) - m_{max}^{new}}*l(x^{(2)}) lallnew=l(x(1))new+l(x(2))new=em(x(1))mmaxnewl(x(1))+em(x(2))mmaxnewl(x(2))

  因为在执行完数据块 x ( 1 ) x^{(1)} x(1) 之后,我们保存了 m m a x = m ( x ( 1 ) ) m_{max} = m(x^{(1)}) mmax=m(x(1)) , l a l l = l ( x ( 1 ) ) l_{all} = l(x^{(1)}) lall=l(x(1)) ,替换 m ( x ( 1 ) ) m(x^{(1)}) m(x(1)) l ( x ( 1 ) ) l(x^{(1)}) l(x(1)) ,所以上式就等价为:
l a l l n e w = l ( x ( 1 ) ) n e w + l ( x ( 2 ) ) n e w = e m m a x − m m a x n e w ∗ l a l l + e m ( x ( 2 ) ) − m m a x n e w ∗ l ( x ( 2 ) ) l_{all}^{new} = l(x^{(1)})_{new} + l(x^{(2)})_{new} = e^{m_{max} - m_{max}^{new}}*l_{all} + e^{m(x^{(2)}) - m_{max}^{new}}*l(x^{(2)}) lallnew=l(x(1))new+l(x(2))new=emmaxmmaxnewlall+em(x(2))mmaxnewl(x(2))
  上面这个公式,也就是上面第 6 步得到的公式。现在我们得到的 m a x m a x n e w max_{max}^{new} maxmaxnew 就是迭代到当前数据块的全局最大值, l a l l n e w l_{all}^{new} lallnew 就是迭代到当前数据块softmax分母部分的全局求和结果。

  1. 现在softmax的分母已经被更新成了全局的结果,现在就要把分子也更新成全局的结果就行了。分子的更新结果很简单,还是以更新 x ( 2 ) x^{(2)} x(2) 的分子为例,在计算 x ( 2 ) x^{(2)} x(2) 的softmax的过程中,分子分母同时除以了 x ( 2 ) x^{(2)} x(2) 的局部最大值的 e m ( x ( 2 ) ) e^{m(x^{(2)})} em(x(2)),所以现在要对分子的结果进行还原,先乘以局部最大值的 e m ( x ( 2 ) ) e^{m(x^{(2)})} em(x(2)),然后在除以全局的最大值的 e m m a x n e w e^{m_{max}^{new}} emmaxnew,公式表示如下:
    e x ( 2 ) − m ( x ( 2 ) ) ∗ e m ( x ( 2 ) ) e m m a x n e w = f ( x ( 2 ) ) ∗ e m ( x ( 2 ) ) − m m a x n e w \frac {e^{x^{(2)} - m(x^{(2)})} * e^{m(x^{{(2)}})}}{e^{m_{max}^{new}}} = f(x^{(2)})*e^{m(x^{(2)})-m_{max}^{new}} emmaxnewex(2)m(x(2))em(x(2))=f(x(2))em(x(2))mmaxnew
      同理,更新后 x ( 1 ) x^{(1)} x(1) 的分子如下:
    e x ( 1 ) − m ( x ( 1 ) ) ∗ e m ( x ( 1 ) ) e m m a x n e w = f ( x ( 1 ) ) ∗ e m ( x ( 1 ) ) − m m a x n e w \frac {e^{x^{(1)} - m(x^{(1)})} * e^{m(x^{{(1)}})}}{e^{m_{max}^{new}}} = f(x^{(1)})*e^{m(x^{(1)})-m_{max}^{new}} emmaxnewex(1)m(x(1))em(x(1))=f(x(1))em(x(1))mmaxnew

  2. 现在就可以计算 x ( 1 ) x^{(1)} x(1) x ( 2 ) x^{(2)} x(2) 迭代到此时的“全局”softmax了。
    s o f t m a x ( x ( 1 ) ) n e w = f ( x ( 1 ) ) ∗ e m ( x ( 1 ) ) − m m a x n e w l a l l n e w = s o f t m a x ( x ( 1 ) ) ∗ l ( x ( 1 ) ) ∗ e m ( x ( 1 ) ) − m m a x n e w l a l l n e w softmax(x^{(1)})_{new} = \frac{f(x^{(1)})*e^{m(x^{(1)})-m_{max}^{new}}}{l_{all}^{new}} = \frac{softmax(x^{(1)})*l(x^{(1)})*e^{m(x^{(1)})-m_{max}^{new}}}{l_{all}^{new}} softmax(x(1))new=lallnewf(x(1))em(x(1))mmaxnew=lallnewsoftmax(x(1))l(x(1))em(x(1))mmaxnew
    s o f t m a x ( x ( 2 ) ) n e w = f ( x ( 2 ) ) ∗ e m ( x ( 2 ) ) − m m a x n e w l a l l n e w = s o f t m a x ( x ( 2 ) ) ∗ l ( x ( 2 ) ) ∗ e m ( x ( 2 ) ) − m m a x n e w l a l l n e w softmax(x^{(2)})_{new} = \frac{f(x^{(2)})*e^{m(x^{(2)})-m_{max}^{new}}}{l_{all}^{new}} = \frac{softmax(x^{(2)})*l(x^{(2)})*e^{m(x^{(2)})-m_{max}^{new}}}{l_{all}^{new}} softmax(x(2))new=lallnewf(x(2))em(x(2))mmaxnew=lallnewsoftmax(x(2))l(x(2))em(x(2))mmaxnew
      上面公式中的 s o f t m a x ( x ( 1 ) ) , s o f t m a x ( x ( 2 ) ) , l ( x ( 1 ) ) , l ( x ( 2 ) ) , m ( x ( 1 ) ) , m ( x ( 2 ) ) , m m a x n e w softmax(x^{(1)}),softmax(x^{(2)}),l(x^{(1)}),l(x^{(2)}),m(x^{(1)}),m(x^{(2)}),m_{max}^{new} softmax(x(1)),softmax(x(2)),l(x(1)),l(x(2)),m(x(1)),m(x(2)),mmaxnew 等都是已知的中间结果,不用重新计算,也不用重新读取 x ( 1 ) x^{(1)} x(1) x ( 2 ) x^{(2)} x(2) 数据块。

  3. 将经过数据块 x ( 1 ) x^{(1)} x(1) x ( 2 ) x^{(2)} x(2) 计算得到的 m m a x n e w m_{max}^{new} mmaxnew l a l l n e w l_{all}^{new} lallnew 更新到 m m a x = m m a x n e w m_{max} = m_{max}^{new} mmax=mmaxnew l a l l = l a l l n e w l_{all} = l_{all}^{new} lall=lallnew,将数据块 x ( 1 ) x^{(1)} x(1) x ( 2 ) x^{(2)} x(2) 的计算结果看做一个整体作为 x ( 1 ) x^{(1)} x(1),将读取的新数据块 x ( 2 ) x^{(2)} x(2) 作为上面的 x ( 2 ) x^{(2)} x(2),继续迭代下去,直到完成所有数据块的计算,这样就得到了全局的softmax结果
    在这里插入图片描述

import numpy as np
import torch


def softmax(x):
    m_x = np.max(x)
    f_x = np.exp(x - m_x)
    l_x = np.sum(f_x)
    soft_x = f_x / l_x
    return m_x, f_x, l_x, soft_x


m_x1, f_x1, l_x1, soft_x1 = softmax(np.array([1, 2]))
m_x2, f_x2, l_x2, soft_x2 = softmax(np.array([3, 4]))
m_x_new = np.max([m_x1, m_x2])
l_new_all = np.exp(m_x1 - m_x_new) * l_x1 + np.exp(m_x2 - m_x_new) * l_x2
soft_x1_new = soft_x1 * l_x1 * np.exp(m_x1 - m_x_new) / l_new_all
soft_x2_new = soft_x2 * l_x2 * np.exp(m_x2 - m_x_new) / l_new_all
soft = torch.nn.functional.softmax(torch.Tensor([1, 2, 3, 4]), dim=0)

# [0.0320586  0.08714432] [0.23688282 0.64391426]
print(soft_x1_new, soft_x2_new)
# [0.0320586  0.08714432 0.23688284 0.6439143 ]
print(soft.numpy())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1102312.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【UE】安装下载的插件文件夹到虚幻引擎

比如我淘宝上购买了一个插件文件夹,解压后内容如下: 找到电脑上虚幻引擎(这里以UE5.1为例)的位置,可以看到里面有一个名字为“Plugins”的文件夹 在此文件夹中找到“Marketplace”文件夹 然后将下载的插件文件夹放到“…

智慧远程医疗服务:从零开始搭建互联网医院APP

互联网医院APP作为远程医疗服务的一部分,正在为患者和医生带来更便捷的医疗体验。本文将探讨如何从零开始构建一个互联网医院APP,包括关键步骤、技术要点和挑战。 一、确定项目目标和范围 在开始之前,您需要明确定义您的互联网医院APP的目标…

AI_Neural Network_Note (二)

NN Predict logistic regression 预测的过程其实只是based on 一个简单的逻辑回归logistic regression公式 z dot(w,x) b (x1 * w1 x2 * w2 x3 * w3) b dot(a,b): 向量a和向量b的点积(内积)运算。 点积是两个向量的对应分量相乘,并将…

什么是接口测试,接口测试怎么玩,接口自动化测试怎么玩?

前言 最近在找工作,因为是做纯服务端测试的,所以面试过程中面试官难免会问,怎么设计接口测试用例,怎么做接口自动化测试?会象征性的考一下基本功。 下面就接口测试,或者说服务端测试,梳理一下我…

Spring framework Day19:Spring AOP xml配置示例二

一、开始学习 1、新建项目&#xff0c;结构如下 2、添加 spring 依赖 <!-- spring 的核心依赖 --><dependencies><!-- https://mvnrepository.com/artifact/org.springframework/spring-context --><dependency><groupId>org.springframework&l…

2023,简历石沉大海?软件测试岗位真的已经饱和了....

各大互联网公司的接连裁员&#xff0c;政策限制的行业接连消失&#xff0c;让今年的求职雪上加霜&#xff0c;想躺平却没有资本&#xff0c;还有人说软件测试岗位饱和了&#xff0c;对此很多求职者深信不疑&#xff0c;因为投出去的简历回复的越来越少了。 另一面企业招人真的…

Redis数据结构之ziplist

前言 Redis 为了提高内存效率&#xff0c;设计了一种特殊的数据结构 ziplist&#xff08;压缩列表&#xff09;。ziplist 本质是一段字节数组&#xff0c;采用了一种紧凑的、连续存储的格式&#xff0c;可以有效地压缩数据&#xff0c;提高内存效率。 hash、zset 在数据量比较…

CSS 基础知识-01

CSS 基础知识 1.CSS概述2. CSS引入方式3. 选择器4.文字控制属性5. 复合选择器6. CSS 特性7.背景属性8.显示模式9.选择器10.盒子模型 1.CSS概述 2. CSS引入方式 3. 选择器 4.文字控制属性 5. 复合选择器 6. CSS 特性 7.背景属性 8.显示模式 9.选择器 <!DOCTYPE html> <…

AutoGPT:自动化GPT原理及应用实践

一、AutoGPT介绍 想象一下&#xff0c;生活在这样一个世界里&#xff0c;你有一个人工智能助手&#xff0c;它不仅能够理解你的需求&#xff0c;而且还能够与你一起学习与成长。人工智能已无缝融入我们工作、生活&#xff0c;并帮助我们有效完成各种目标。大模型技术的发展与应…

【数据分享】2022年我国30米分辨率的地形粗糙度(起伏度)数据(免费获取)

地形数据&#xff0c;也叫DEM数据&#xff0c;是我们在各项研究中最常使用的数据之一。之前我们分享过2022年哥白尼30米分辨率的DEM数据&#xff0c;该数据被公认为是全球最佳的开源DEM数据之一&#xff0c;甚至没有之一&#xff08;可查看之前的文章获悉详情&#xff09;&…

Jmeter的性能测试

性能测试的概念 定义&#xff1a;软件的性能是软件的一种非功能特性&#xff0c;它关注的不是软件是否能够完成特定的功能&#xff0c;而是在完成该功能时展示出来的及时性。 由定义可知性能关注的是软件的非功能特性&#xff0c;所以一般来说性能测试介入的时机是在功能测试…

particles 粒子背景插件在vue3中的使用

particles 粒子背景插件在vue3中的使用 概述使用完整代码概述 npm 链接 https://www.npmjs.com/package/particles.vue3 GitHub地址 https://github.com/tsparticles/vue3 配置参数说明: color: String类型 默认’#dedede’。粒子颜色。particleOpacity: Number类型 默认0.7。…

【Linux】线程互斥与同步

文章目录 一.Linux线程互斥1.进程线程间的互斥相关背景概念2互斥量mutex3.互斥量的接口4.互斥量实现原理探究 二.可重入VS线程安全1.概念2.常见的线程不安全的情况3.常见的线程安全的情况4.常见的不可重入的情况5.常见的可重入的情况6.可重入与线程安全联系7.可重入与线程安全区…

【halcon】halcon轮廓总结之select_contours_xld

前言 select_contours_xld 我认为是一个非常常用且实用的算子&#xff0c;用于对轮廓进行筛选。 简介 这段文档描述了一个名为"SelectContoursXld"的操作&#xff0c;用于根据不同特征选择XLD&#xff08;XLD是一种图像数据表示形式&#xff0c;表示轮廓线&#x…

使用 Bard 的 Google Hotel 插件查询酒店

使用 Bard 的 Google Hotel 插件&#xff0c;您可以通过以下步骤找到符合您需求的酒店&#xff1a; 在 Google 搜索中打开 Bard 插件。输入您要搜索的城市或酒店名称。选择您要搜索的日期和入住人数。选择您要搜索的酒店类型和价格范围。单击“搜索”按钮。 Find hotels for a…

OpenCV实战完美实现眨眼疲劳检测!!

目录 1&#xff0c;项目流程 2&#xff0c;代码实现 3&#xff0c;结果展示 应用场景主要是在监控系统和驾驶员安全监测中&#xff1a; 监控系统&#xff1a;可以将该项目应用于监控摄像头的视频流中&#xff0c;实时检测闭眼行为。通过实时计算闭眼次数和眼睛长宽比&#x…

ubuntu20.04安装FTP服务

安装 sudo apt-get install vsftpd# 设置开机启动并启动ftp服务 systemctl enable vsftpd systemctl start vsftpd#查看其运行状态 systemctl status vsftpd #重启服务 systemctl restart vsftpdftp用户 sudo useradd -d /home/ftp/ftptest -m ftptest sudo passwd ftptest…

数字签名 及 数字证书 原理笔记

这里是对 数字签名 及 数字证书 原理该视频做的一个笔记&#xff0c;链接 前言 如果对一些加密算法不懂可以参考这篇文章 数字签名 小明发送文件给小红时对文件做出签名 将文件进行hash算法加密得到hash值&#xff0c;并且对该hash值使用私钥进行加密&#xff08;私钥加密的…

接口加密解决方案:Python的各种加密实现!

01、前言 在现代软件开发中&#xff0c;接口测试已经成为了不可或缺的一部分。随着互联网的普及&#xff0c;越来越多的应用程序都采用了接口作为数据传输的方式。接口测试的目的是确保接口的正确性、稳定性和安全性&#xff0c;从而保障系统的正常运行。 在接口测试中&…