Pytorch官方FlashAttention速度测试

news2025/6/26 11:57:43

在Pytorch的2.2版本更新文档中,官方重点强调了通过实现FlashAtteneion-v2实现了对scaled_dot_product_attention约2X左右的加速。
在这里插入图片描述
今天抽空亲自试了下,看看加速效果是否如官方所说。测试前需要将Pytorch的版本更新到2.2及以上,下面是测试代码,一个是原始手写的Self-Attention的实现,一个是使用Pytorch官方的scaled_dot_product_attention接口:

import time
import torch
import torch.nn.functional as F


def main():
    repeat = 100
    device = torch.device("cuda:0")
    dtype = torch.float16

    query = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)
    key = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)
    value = torch.rand(32, 8, 128, 64, dtype=dtype, device=device)
    scale_factor = 0.125

    ori_time_list = []
    for _ in range(repeat):
        torch.cuda.synchronize(device=device)
        time_start = time.perf_counter()
        # 原始Self-Attention实现
        res = torch.softmax(query @ key.transpose(-2, -1) * scale_factor, dim=-1) @ value
        torch.cuda.synchronize(device=device)
        time_end = time.perf_counter()
        ori_time_list.append(time_end - time_start)

    fa_time_list = []
    for _ in range(repeat):
        torch.cuda.synchronize(device=device)
        time_start = time.perf_counter()
        with torch.backends.cuda.sdp_kernel(enable_math=False):
            # 使用Pytorch官方提供的FA实现
            res_fa = F.scaled_dot_product_attention(query, key, value, scale=scale_factor)
        torch.cuda.synchronize(device=device)
        time_end = time.perf_counter()
        fa_time_list.append(time_end - time_start)

    diff = (res - res_fa).abs().max()
    ratio = [ori_time_list[i] / fa_time_list[i] for i in range(repeat)]
    avg_ratio = sum(ratio[1:]) / len(ratio[1:])
    print(f"max diff: {diff}")
    print(f"avg speed up ratio: {avg_ratio}")


if __name__ == '__main__':
    main()

执行以上代码,终端输出如下:

max diff: 0.00048828125
avg speed up ratio: 2.2846881043417118

这里使用的设备是RTX4070,跑了很多次发现确实加速2X左右,看来以后训练或者推理时可以考虑直接使用官方的scaled_dot_product_attention接口了。但是这里也发现了两个问题,一个是原始手写的Self-Attention的计算结果和直接调用scaled_dot_product_attention接口得到的结果差异有点大(注意,这里计算的Tensor都是FP16精度的),如果我切换到FP32精度差异会再小两个数量级。第二个问题是如果使用FP32的话实测没有明显加速,这个就很奇怪了,官方文档里并没有说专门针对FP16精度优化的。关于这两个问题,暂时猜测是环境问题,或许换个GPU硬件设备或者更新下驱动啥的就可能没有这些问题了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1584155.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【opencv】示例-facial_features.cpp 使用Haarcascade分类器检测面部特征点

// 包含OpenCV库中有关对象检测的头文件 #include "opencv2/objdetect.hpp" // 包含OpenCV库中有关高层GUI函数的头文件 #include "opencv2/highgui.hpp" // 包含OpenCV库中有关图片处理的头文件 #include "opencv2/imgproc.hpp"// 包含输入输出…

Vue的学习之旅-part6-循环的集中写法与ES6增强语法

Vue的学习之旅-循环的集中写法与ES6增强语法 vue中的几种循环写法for循环for in 循环 for(let i in data){}for of 循环 for(let item of data){}reduce() 遍历 reduce( function( preValue, item){} , 0 ) ES6增强写法 类似语法糖简写对象简写函数简写 动态组件中使用 <kee…

Web漏洞-文件上传之内容逻辑数组

图片一句话制作方法&#xff1a; copy 1.png /b shell.php /a webshell.jpg 具体示例见upload-labs 的14-17 二次渲染----见Pass-18 用/.或者%00绕过&#xff1a;Pass-20----Pass-21 CVE-2017-12615复现 创好环境后打开环境&#xff0c;再访问ip8080 抓包发送数据 Shell的…

M1 Flutter SDK的安装和环境配置

前言 作为iOS 开发&#xff0c;观望了许久的Flutter &#xff0c;还是对它下手了&#xff0c;不是故意要卷&#xff0c;没办法工作需要&#xff01;既然要学Flutter&#xff0c;首先就得配置Flutter的相关环境&#xff0c;由于我的是M1 芯片的电脑&#xff0c;记录下来配置过程…

四川古力未来科技抖音小店:安全守护,购物无忧

在当下数字化浪潮席卷全球的背景下&#xff0c;电商行业迎来了前所未有的发展机遇。四川古力未来科技抖音小店作为新兴的电商力量&#xff0c;以其独特的魅力和强大的安全保障措施&#xff0c;赢得了广大消费者的青睐和信任。本文将深入探讨四川古力未来科技抖音小店在安全方面…

java+saas模式医院云HIS系统源码Java+Spring+MySQL + MyCat融合BS版电子病历系统,支持电子病历四级

javasaas模式医院云HIS系统源码JavaSpringMySQL MyCat融合BS版电子病历系统&#xff0c;支持电子病历四级 云HIS系统是一款满足基层医院各类业务需要的健康云产品。该产品能帮助基层医院完成日常各类业务&#xff0c;提供病患预约挂号支持、病患问诊、电子病历、开药发药、会员…

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单视频处理实战案例 之八 简单视频素描效果

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单视频处理实战案例 之八 简单视频素描效果 目录 Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单视频处理实战案例 之八 简单视频素描效果 一、简单介绍 二、简单指定视频某片段快放效果实现原理 三、简单指定视频某…

盲人出行新篇章:一款悄然改变生活的盲人导航应用

作为一名资深记者&#xff0c;我始终关注并报道那些以科技创新改善特殊群体生活质量的故事。近期&#xff0c;一款名为蝙蝠避障的专业盲人导航辅助工具引起了我的关注。它凭借独树一帜的避障技术&#xff0c;悄然间为视障群体的独立出行开启了全新篇章&#xff0c;带来了显著且…

HarmonyOS实战开发-本示例模拟倒计时场景,如何实现振动。

介绍 本示例模拟倒计时场景&#xff0c;通过ohos.vibrator 等接口来实现振动。 效果预览 使用说明 1.点击倒计时文本&#xff0c;弹出时间选择框&#xff0c;选择任意时间&#xff0c;点击确认&#xff0c;倒计时文本显示选择的时间。 2.点击start&#xff0c;开始倒计时&a…

【JavaEE初阶系列】——网络初识—TCP/IP五层网络模型

目录 &#x1f6a9;网络的发展史 &#x1f388;局域网LAN &#x1f388;广域网WAN &#x1f6a9;网络通信基础 &#x1f388;IP地址 &#x1f388;端口号 &#x1f388;协议类型 &#x1f388;五元组 &#x1f6a9;协议分层 &#x1f388;什么是协议分层 &#x…

查看TensorFlow已训模型的结构和网络参数

文章目录 概要流程 概要 通过以下实例&#xff0c;你将学会如何查看神经网络结构并打印出训练参数。 流程 准备一个简易的二分类数据集&#xff0c;并编写一个单层的神经网络 train_data np.array([[1, 2, 3, 4, 5], [7, 7, 2, 4, 10], [1, 9, 3, 6, 5], [6, 7, 8, 9, 10]]…

【opencv】示例-essential_mat_reconstr.cpp 从两幅图像中恢复3D场景的几何信息

导入OpenCV的calib3d, highgui, imgproc模块以及C的vector, iostream, fstream库。定义了getError2EpipLines函数&#xff0c;这个函数用来计算两组点相对于F矩阵&#xff08;基础矩阵&#xff09;的投影误差。定义了sgn函数&#xff0c;用于返回一个双精度浮点数的符号。定义了…

SQLite超详细的编译时选项(十六)

返回&#xff1a;SQLite—系列文章目录 上一篇&#xff1a;SQLite数据库文件格式&#xff08;十五&#xff09; 下一篇&#xff1a;SQLite 在Android安装与定制方案&#xff08;十七&#xff09; 1. 概述 对于大多数目的&#xff0c;SQLite可以使用默认的 编译选项。但是…

2.HTML常用标签之表单标签

1.HTML常用标签之表单标签 w3c所有标签列表 HTML常用标签之表单标签

结合 tensorflow.js 、opencv.js 与 Ant Design 创建美观且高性能的人脸动捕组件并发布到InsCode

系列文章目录 如何在前端项目中使用opencv.js | opencv.js入门如何使用tensorflow.js实现面部特征点检测tensorflow.js 如何从 public 路径加载人脸特征点检测模型tensorflow.js 如何使用opencv.js通过面部特征点估算脸部姿态并绘制示意图tensorflow.js 使用 opencv.js 将人脸…

【STM32篇】DRV8425驱动步进电机

【STM32篇】4988驱动步进电机_hr4988-CSDN博客 在上篇文章中使用了HR4988实现了步进电机的驱动&#xff0c;在实际运用过程&#xff0c;HR4988或者A4988驱动步进电机会存在电机噪音太大的现象。本次将向各位友友介绍一个驱动简单且非常静音的一款步进电机驱动IC。 1.DRV8425简介…

苹果开发者后台添加udid后,xcode中 Devices 数量没有更新问题

删除 文件夹 /Users/…/Library/MobileDevice/Provisioning Profiles 如何打开&#xff1a;https://zhuanlan.zhihu.com/p/563928113 回到Xcode刷新包名下面的警告验证&#xff08;可能需要翻墙&#xff09; 完毕&#xff01;

Java异常处理机制详解:多层方法调用与异常传播(day23)

1.数组下标越界 2.多个处理异常 上面这两个代码的区别就是有无 System.out.println("抛出了NumberFormatException"); System.out.println("抛出了ArrayIndexOutOfBoundsException"); 第一种是不论捕获到哪种异常&#xff0c;都只会调用e.printStack…

探索GlusterFS:开源分布式文件系统

目录 引言 一、GlusterFS简介 &#xff08;一&#xff09;基本介绍 &#xff08;二&#xff09;GlusterFS特点 &#xff08;三&#xff09;GlusterFS术语 &#xff08;四&#xff09;GlusterFS工作流程 二、GlusterFs的卷类型 &#xff08;一&#xff09;卷类型 &…

【面试题】微博、百度等大厂的排行榜如何实现?

背景 现如今每个互联网平台都会提供一个排行版的功能&#xff0c;供人们预览最新最有热度的一些消息&#xff0c;比如百度&#xff1a; 再比如微博&#xff1a; 我们要知道&#xff0c;这些互联网平台每天产生的数据是非常大&#xff0c;如果我们使用MySQL的话&#xff0c;db实…