setup.py:python项目中,setup.py用于管理项目的构建、打包和分发过程。这个文件通常包含项目的元数据以及如何构建和安装模块的指令- 三个相关命令
- 构建扩展模块:
python setup.py build_ext - 清理构建文件:
python setup.py clean - 安装到系统:
python setup.py install。在项目根目录下,通过运行该命令来构建和安装你的包,这将会执行setup.py文件中的setup()函数,并根据其中的配置将包构建成一个分发包,并安装到python环境中
- 构建扩展模块:
- 运行
python setup.py install后发生的事情:- 环境检查:python检查setup里面列出的依赖项是否已经安装。若没有则尝试安装
- 构建包:使用
find_packages()找到所有可用的子模块并准备构建 - 编译扩展:如果有C/C++扩展模块,使用指定的构建工具(如Ninja)来编译这些扩展
- 安装包:将包和所有依赖项安装到python的
site-packages目录,使得包可以在python中被导入和使用 - 验证安装:安装完后,用户可在python环境中使用
import PACKAGE_NAME来验证安装是否成功
- 三个相关命令
也就是,
setup.py就是为了把编译后的结果打包成一个python包然后安装在环境当中的。setup.py其中包含了编译流程(ext_modules),等运行完之后,用户可在python环境中使用import PACKAGE_NAME来验证安装是否成功
setup(
name=PACKAGE_NAME,
version=get_package_version(),
packages=find_packages( // 用于查找包中可分发的所有子模块。exclude参数指定要排除的目录,这些目录不会被打包。通常会排除测试、文档和构建目录
exclude=(
"build",
"csrc",
"include",
"tests",
"dist",
"docs",
"benchmarks",
"flash_attn.egg-info",
)
),
author="Tri Dao",
author_email="tri@tridao.me",
description="Flash Attention: Fast and Memory-Efficient Exact Attention",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/Dao-AILab/flash-attention",
classifiers=[ // 一组字符串,用于提供关于包的元数据,比如python版本、许可证类型和操作系统
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",
"Operating System :: Unix",
],
ext_modules=ext_modules, // 指定C/C++扩展模块,如果没有扩展模块通常设为None。如果有C/C++扩展模块,就使用的构建工具(如Ninja)来编译这些扩展
cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": NinjaBuildExtension} // 用于定义命令的字典
if ext_modules
else {
"bdist_wheel": CachedWheelsCommand,
},
python_requires=">=3.8",
install_requires=[
"torch",
"einops",
],
setup_requires=[
"packaging",
"psutil",
"ninja",
],
)
- “编译”与
ext_modules- 编译:
如上面所说,运行python setup.py install的过程会检查是否有C/C++扩展模块,若有的话就进行编译。
具体来说,编译扩展是将用C/C++编写的代码编译成共享库(动态链接库),这个库可以被python直接导入和使用。这使得python能够调用高性能的底层代码,通常用于加速计算密集型任务。
编译完成后,生成的共享库通常会是一个.so(Linux)、.dll(Windows)或.dylib(macOS)结尾的文件,这些文件可以在python中通过import语句直接导入。 ext_modules:是一个列表,包含了所有需要编译的扩展模块。通常由setuptools的Extension类构建(from setuptools import Extension)。这里是使用from torch.utils.cpp_extension import CUDAExtention。在setup()函数中,ext_modules参数指向这个扩展模块列表,当用户运行python setup.py install时,setuptools会读取这些信息,调用编译器进行编译。如果定义了多个扩展模块,它们会在同一次构建过程中被编译并链接到最终的python包中。
编译后的扩展模块可以被python代码直接调用,就像普通的python模块一样。
如下面,name是“flash_attn_2_cuda”的意思就是编译好的库怎么引用呢,就是通过import flash_attn_2_cuda来引用。
ext_modules.append( CUDAExtension( name="flash_attn_2_cuda", sources=[ "csrc/flash_attn/flash_api.cpp", "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu", ], extra_compile_args={ "cxx": ["-O3", "-std=c++17"] + generator_flag, "nvcc": append_nvcc_threads( [ "-O3", "-std=c++17", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", # "--ptxas-options=-v", # "--ptxas-options=-O2", # "-lineinfo", # "-DFLASHATTENTION_DISABLE_BACKWARD", # "-DFLASHATTENTION_DISABLE_DROPOUT", # "-DFLASHATTENTION_DISABLE_ALIBI", # "-DFLASHATTENTION_DISABLE_SOFTCAP", # "-DFLASHATTENTION_DISABLE_UNEVEN_K", # "-DFLASHATTENTION_DISABLE_LOCAL", ] + generator_flag + cc_flag ), }, include_dirs=[ Path(this_dir) / "csrc" / "flash_attn", Path(this_dir) / "csrc" / "flash_attn" / "src", Path(this_dir) / "csrc" / "cutlass" / "include", ], ) ) ext_modules.append( CUDAExtension( name="flash_attn_2_cuda", sources=renamed_sources, extra_compile_args=extra_compile_args, include_dirs=include_dirs, ) )- 通过编译扩展,开发者可以利用C/C++的性能优势,同时保持python的易用性,这对于需要高性能计算的应用尤为重要
- 编译:
torch.utils.cpp_extension.CUDAExtension介绍
是pytorch提供的一个类,用于方便地构建和编译CUDA扩展。它封装了与CUDA相关的编译过程,允许用户在pytorch中轻松集成自定义的CUDA代码- 几个功能:
- 编译CUDA代码:允许用户指定CUDA源文件及相关的编译选项,从而生成可以在python中使用的共享库
- 集成C++代码:用户可以将C++代码与CUDA代码结合,创建复杂的扩展
- 简化配置:提供了一种简单的方法来管理编译过程中的各种设置,如头文件路径、库文件、编译器标志等
- 使用方法:
from torch.utils.cpp_extension import CUDAExtension, setup ext_modules = [ CUDAExtension( name='my_cuda_extension', # 模块名称 sources=['src/my_cuda_extension.cpp', # 源文件。即包含实际代码的文件,定义了要实现的功能或算法 'src/my_cuda_extension_kernel.cu'], include_dirs=['/path/to/include'], # 包含头文件的目录。包含了函数声明、宏定义和数据结构的定义。头文件使得不同源文件可以共享和复用代码 libraries=['mylib'], # 链接的库。是编译时需要引用的外部库,它们提供额外的功能,通常是在编译的过程中 # 与扩展模块进行链接。链接库可以是静态库(.a文件)或动态库(.so或.dll文件) library_dirs=['/path/to/lib'], # 库文件路径。指存放链接库的目录。当编译器在链接阶段寻找库文件时会使用这个路径 extra_compile_args={ "cxx": ["-O3", "-std=c++17"] + generator_flag, # -03:启用最高级别的优化,通常会生成更快但是编译时间更长的代码 # -std=c++17:指定使用C++17标准 # +generator_flag:追加其他生成器特定的编译选项。generator_flag通常是动态定义的,可能与编译器或构建工具有关 # 前面定义了 generator_flag = ["-DOLD_GENERATOR_PATH"] "nvcc": append_nvcc_threads( # 这里包含了为nvcc(nvidia CUDA编译器)指定的编译选项 [ "-O3", "-std=c++17", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", # 启用快速数学库以提高性能,但可能以牺牲准确性为代价 # "--ptxas-options=-v", # 编译时显示ptxas的详细信息,有助于调试 # "--ptxas-options=-O2", # "-lineinfo", # "-DFLASHATTENTION_DISABLE_BACKWARD", # "-DFLASHATTENTION_DISABLE_DROPOUT", # "-DFLASHATTENTION_DISABLE_ALIBI", # "-DFLASHATTENTION_DISABLE_SOFTCAP", # "-DFLASHATTENTION_DISABLE_UNEVEN_K", # "-DFLASHATTENTION_DISABLE_LOCAL", ] + generator_flag + cc_flag ), }, ) ]
- 几个功能:
- 所以,要看的就是
sources里的文件,这些就是要编译的CUDA源文件- 它们实现了不同版本的前向和反向传播算法:fp16/bf16、fwd/bwd、hdim、causal、split
flash_api.cpp:Flash Attention API 的定义和实现,用于提供 Python 和 CUDA 代码之间的接口。flash_fwd_hdimXX_fp16_sm80.cu:这些是 CUDA 源文件,涉及前向计算的实现,hdimXX 表示模型的隐藏维度(例如,32, 64, 96, 128, 160, 192, 256),fp16 指使用16位半精度浮点数(另外还有bf16),sm80 指该文件是为特定的 CUDA 架构(例如,80对应于 Ampere架构)编写的flash_fwd_hdimXX_fp16_causal_sm80.cu:这些文件是针对因果前向计算的实现(含掩码),适用于语言模型等需要因果注意力的任务。它们同样根据不同的隐藏维度和数据类型进行分类flash_bwd_hdimXX_fp16_sm80.cu:实现了backward反向传播的计算,用于训练过程中的梯度计算flash_bwd_hdimXX_fp16_causal_sm80.cu:实现了因果模型的反向传播flash_fwd_split_hdimXX_fp16_sm80.cu:实现了针对特定隐藏维度的分割前向计算,可能是为了更高效地处理大型输入(??)flash_fwd_split_hdimXX_fp16_causal_sm80.cu
- 所以改的话,就是改fwd、causal=false、(看下默认参数配置?
flash_api.cppset_params_fpropset_params_dgradrun_mha_fwdnum_splits_heuristicset_params_splitkvset_params_alibimha_fwdmha_varlen_fwdrun_mha_bwdmha_bwdmha_varlen_bwdmha_fwd_kvcache- pybind定义:
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "FlashAttention"; m.def("fwd", &mha_fwd, "Forward pass"); // 定义一个名为fwd的函数,绑定到上面的mha_fwd函数,并为该函数提供文档字符串“Forward pass”,这表示该函数实现了前向传播的计算逻辑 m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)"); m.def("bwd", &mha_bwd, "Backward pass"); m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)"); m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache"); }
- 总体调用流程:
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func- 编译好的
.so文件,使用其.fwd
从一个CUDA+Python联合调试的文章里清晰了解了一个CUDA项目的编译过程:
原始项目的目录树为:

其中,cuda_hello.cu是待调试的CUDA代码(里面定义了一个打印hello的核函数和一个主机端调用接口launch_cuda_hello);
pybind_wrapper.cpp使用pybind11这个库将CUDA代码中的主机调用接口函数注册到Python中(具体就是,先创建一个名为cuda_hello的python模块,然后将外部的主机函数launch_cuda_hello与新建python包中的函数名hello关联。最终在python中的使用方法就是:import cuda_hello,然后cuda_hello.hello())。如下,PYBIND11_MODULE是pybind11提供的宏,用于定义一个python模块,下面的代码中,模块名设为cuda_hello,并传入了m作为模块对象的引用,通过m为这个模块添加函数和类:
PYBIND11_MODULE(cuda_hello, m) { m.def("hello", &launch_cuda_hello, "A function that launches a CUDA kernel to print Hello"); }
在test_cuda_hello.py中,通过动态链接库导入cuda_hello这个包,并通过上述方法调用该包中的launch_cuda_hello函数
import cuda_hello
cuda_hello.hello()
在CMakeLists.txt文件中,设置CUDA标准、CUDA架构、C++ 标准等一系列配置,以及配置刚刚定义的编译源代码:查找pybind11包、添加CUDA源代码并创建共享库(add_library(cuda_functions SHARED src/cuda_hello.cu))、创建pybind11模块(pybind11_add_module(cuda_hello src/pybind_wrapper.cpp))、将CUDA函数库链接到pybind11模块(target_link_libraries(cuda_hello PRIVATE cuda_functions))。
即准备好pybind11->把cuda源文件打包成共享库->用pybind11创建一个python模块->将cuda共享库链接到python模块中,使python模块能执行GPU代码- 该
.fwd绑定的是flash_api.cpp中的mha_fwd函数 mha_fwd在完成初始化后,调用run_mha_fwd(params, stream)(依然定义在flash_api.cpp中)进行前向计算run_mha_fwd会根据 – 1)数据类型(params.is_bf16)、2)维度(params.d)、3)是否采用causal attention(params.is_causal) – 来调用run_mha_fwd_函数(或若force_split_kernel,调用run_mha_fwd_splitkv_dispatch函数)并传入elem_type、kHeadDim、Is_causal三个参数run_mha_fwd_函数声明在flash.h中(在flash_api.cpp中要include flash.h):template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream);
flash_fwd_launch_template.h介绍:通过宏定义和模板参数来生成不同变体的内核函数,从而适配不同的硬件架构、输入条件和操作模式- 包含头文件:主要涉及CUDA上下文、flash-attention计算
#include <ATen/cuda/CUDAContext.h>:是pytorch中的一个头文件,这个文件定义了与CUDA相关的上下文管理功能,主要用于处理CUDA设备的初始化、设备上下文切换以及流管理。ATen是pytorch的底层tensor库,提供了tensor计算、自动求导等基础功能CUDAContext:负责设备初始化、设备选择、与CUDA流有关的操作(CUDA流允许在GPU上并行执行多个任务)、以及与CUDA相关的资源管理- 该头文件是pytorch中实现GPU加速计算的关键部分
#include <ATen/cuda/CUDAContext.h> // 获取当前 CUDA 设备信息 int current_device = at::cuda::current_device(); // 切换 CUDA 设备 at::cuda::set_device(0); // 获取默认 CUDA 流 cudaStream_t stream = at::cuda::getCurrentCUDAStream();
#include "static_switch.h":通过一系列宏定义(如FP16_SWITCH、HEADDIM_SWITCH、BOOL_SWITCH)来简化和优化在编译时的条件分支处理。这些宏根据布尔或其他条件,在编译或运行时选择执行不同的代#include "flash.h":- 定义如下结构体:
Qkv_params、Flash_fwd_params、Flash_bwd_params - 定义如下函数模版:
run_mha_fwd_、run_mha_fwd_splitkv_dispatch、run_mha_bwd_
- 定义如下结构体:
#include "flash_fwd_kernel.h":主要就是进行attention的计算,且本头文件中定义的函数都放在namespace flash下面。具体定义如下函数:get_lse_tilecompute_attn:计算attention的外部逻辑函数,它会先获取块索引,然后调用compute_attn_1rowblock并将之前定义的参数和当前块索引传进去,进行实际的单行attention计算compute_attn_splitkv:和上面compute_attn的原理差不多,区别就是它支持split kv机制,能适应多头注意力的复杂需求,能通过分割逻辑优化性能compute_attn_1rowblock:用于计算单个行块(row block)上的attentioncompute_attn_1rowblock_splitkvcombine_attn_seqk_parallel:结合多个attention头的计算结果,以计算最终的输出
- 定义了三个核函数:
flash_fwd_kernel、flash_fwd_splitkv_kernel、flash_fwd_splitkv_combine_kernel。分别调用flash::compute_attn、flash::compute_attn_splitkv、flash::combine_attn_seqk_parallel进行attention的计算 - 定义了三个主机函数
run_flash_fwd、run_flash_splitkv_fwd、run_mha_fwd_splitkv_dispatch,分别调用了上面的flash_fwd_kernel、flash_fwd_splitkv_kernel、flash_fwd_splitkv_combine_kernel - 定义了不同维度的主机函数:
run_mha_fwd_hdim32、run_mha_fwd_hdim64、run_mha_fwd_hdim96、run_mha_fwd_hdim128、run_mha_fwd_hdim160、run_mha_fwd_hdim192、run_mha_fwd_hdim256,会调用run_flash_fwd - 也就是
flash_fwd_launch_template实际上是包裹了flash_fwd_kernel.h的实现(现在还未知是从外部的哪里调用了flash_fwd_launch_template,以及内部flash_fwd_kernel.h具体是如何实现的,如果没啥问题,应该就是改这个头文件了。但是有个疑问,就是他函数逻辑是定义在一个头文件里?)
- 包含头文件:主要涉及CUDA上下文、flash-attention计算
src中的包裹逻辑
flash_fwd_kernel.h:实现一行块一行块的attention
flash_fwd_launch_template.h:实现不同维度的run_mha_fwd_hdim256,进行run_flash_fwd函数的调用。run_flash_fwd再根据其他参数进行flash_fwd_kernel的调用,核函数flash_fwd_kernel会调用flash_fwd_kernel.h中的具体计算逻辑
具体在每个flash_fwd_hdim?_bf16?_sm80.cu文件中,会include上面的flash_fwd_launch_template.h,然后具体定义run_mha_fwd_函数:根据参数来调用具体的填满维度的函数,如run_mha_fwd_hdim96
最终在外部接口flash_api.cpp中,调用run_mha_fwd_函数
结论,所以改的话,只需要看flash_fwd_launch_template.h(每准这个也不用改)和flash_fwd_kernel.h即可。前者是分配了不同维度,后者是具体的计算
src中一些概念性定义的头文件
kernel_traits.h:定义了三个结构体struct Flash_kernel_traits:封装了不同CUDA架构的特性和操作,包括定义别名、定义MMA(矩阵乘法原子)、定义SmemCopyAtom和SmemCopyAtomTransposed(共享内存复制原子)struct Flash_fwd_kernel_traits : public Base:继承了上面的struct Flash_kernel_traits,并在前向计算中增加了特定的优化和数据布局方式。总的来说,这个结构体是对flash attention前向计算核函数的执行特性进行描述的,其描述了在GPU上计算attention时所设计的关键参数、内存布局和优化策略。结构体描述的内容包括:
说白了作用就是根据根据CUDA架构选择不同的内存布局、复制方式、核函数参数(如KNThreads、kBlockM等参数控制核函数执行时的线程数和块大小,确保核函数适合在不同的矩阵大小和head_dim下执行)和矩阵运算原子。- 线程和块大小:定义了核函数执行时的线程数、线程块大小、并行计算的warp数,这些参数决定了计算过程中每个线程处理的数据量等
- 内存布局和访问模式:描述了Q、K、V矩阵在shared memory和global memory中的布局方式(
SmemLayoutQ、SmemLayoutKV、GmemLayoutAtom等),通过这些布局来确保在GPU内存结构中高效读取和写入数据;同时使用特定的复制方式(SmemCopyAtom、GmemTiledCopyQKV)来减少共享内存的冲突和优化全局内存的带宽使用 - 架构优化:根据不同的硬件架构选择不同的优化策略,如是否使用
cp.async进行异步数据传输、根据是fp16还是bf16来选择不同的矩阵乘算法(MMA_Atom_Arch) - attention优化:如使用
kHeadDim定义了头部维度如何影响内存分配和复制方式,特别是在不同数据分块策略下,确保高效的矩阵乘法和内存操作
struct Flash_bwd_kernel_traits: public Base
flash_fwd_kernel.h 的具体实现
inline device void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidi, const int m_block)
Kernel_traitsflash_fwd_kernel.h:在模板函数中定义typename Kernel_traitsflash_fwd_launch_template.h-flash_fwd_kernel核函数:flash_fwd_kernel核函数是模板函数(该模板函数又是通过宏来定义的,即通过宏定义固定格式生成多个核函数,然后此flash_fwd_kernel核函数又通过自身模板函数的特性,可传入不同类型参数/不同参数值并在编译时就确定其值),其中就有typename Kernel_traits,进而在该核函数里通过调用flash_fwd_kernel.h中具体的attention计算函数来进行Kernel_traits的传递(传给上面)flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, ...>(params);flash_fwd_launch_template.h-run_flash_fwd主机函数:run_flash_fwd也是模板函数,定义了typename Kernel_traits,进而在该主机函数里通过调用上面的flash_fwd_kernel核函数来进行Kernel_traits的传递// run_flash_fwd函数的定义如下: template<typename Kernel_traits, bool Is_dropout, bool Is_causal> void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream){...} // run_flash_fwd函数中具体调用上面核函数的部分代码如下: auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout&&!Is_softcap,...>; kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);flash_fwd_launch_template.h-run_mha_fwd_hdim?:以run_mha_fwd_hdim64为例,该函数会调用上面的run_flash_fwd函数:
这就找到了Kernel_traits了。根据上面constexpr static int Headdim = 64; run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);run_flash_fwd的函数定义可知,Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>就是具体传入的Kernel_traits。这是一个定义在kernel_traits.h中的结构体,在flash_fwd_launch_template.h中存在#include "flash_fwd_kernel.h",在flash_fwd_kernel.h中存在#include "kernel_traits.h",所以这里可以直接使用
- 具体该函数内执行q、k矩阵乘的部分:

然后又调用了这里
最后调用了cute::gemm,就是cutlass的实现了



















