GPU 算力翻倍,AI 反而变慢了?FlashAttention-4 给出了惊人的答案
如果你最近在关注大模型训练可能会发现一个很反常的现象。GPU 一代比一代强。算力翻倍、Tensor Core 更快、AI 芯片越来越猛。但很多研究人员却发现模型训练速度并没有等比例提升。问题出在哪里答案其实藏在一个很多人忽略的地方Attention。而最近发布的一篇论文FlashAttention-4给出了一个非常有意思的答案。甚至可以说它重新定义了Attention 在 GPU 上应该如何实现。一、AI 世界最重要的一层Attention几乎所有现代 AI 模型——GPTClaudeGemini多模态模型底层都建立在Transformer架构上。而 Transformer 最核心的一步就是Attention。简单理解Attention 的作用就是让模型决定“哪些信息更重要”。例如一句话The animal didnt cross the street because it was too tired模型需要判断“it” 指的是谁是animal还是streetAttention 就是用来解决这种“上下文关联”的。但问题是Attention 的计算复杂度是O(N²)也就是说上下文越长计算量爆炸。这也是为什么很多模型4K8K32K100K token一旦上下文变长算力需求就会疯狂增长。二、FlashAttention改变游戏规则的优化为了加速 Attention研究界过去几年做了很多努力。其中最成功的方案之一就是FlashAttention。它的核心思想很简单减少 GPU 内存访问。因为在 GPU 上数据移动往往比计算更慢。FlashAttention 通过Tile 分块计算让 Attention 在SRAM 中完成计算避免频繁访问显存。结果非常惊人显存减少速度更快长序列训练更稳定于是FlashAttention 成为了很多 AI 框架的默认实现。但问题是——GPU 变了。三、新 GPU 出现了一个“奇怪现象”新一代 GPU ——NVIDIA Blackwell。带来了巨大的算力提升。相比上一代NVIDIA HopperTensor Core 的性能提升超过 2 倍。但很多研究人员跑 benchmark 时发现Attention 并没有变快那么多。于是他们做了一件事情分析 GPU 的瓶颈。结果非常意外。真正拖慢速度的不是矩阵乘法而是Softmax。更具体来说是两个操作指数运算 exp()shared memory 访问换句话说GPU 的矩阵计算已经快到一种程度——其它步骤反而成了拖后腿的。论文把这种现象称为Asymmetric Hardware Scaling四、FlashAttention-4 的核心思路面对这个问题研究人员做了一件非常大胆的事情重新设计 Attention 的 GPU 内核。不是简单优化。而是算法 硬件协同设计。FlashAttention-4 的优化主要集中在四个方向。1 用数学近似替代指数运算Softmax 中最贵的一步是exp(x)在 GPU 中这个操作由特殊单元执行但吞吐量很低研究人员想到一个办法不用 exp而是用多项式逼近。核心思想把指数拆成2^x 2^{整数} × 2^{小数}然后整数部分 → 位运算小数部分 → 多项式近似例如2^x ≈ a bx cx² dx³这些计算可以在FMA 单元完成。结果指数运算速度大幅提升。而在 BF16 精度下误差几乎可以忽略。2 重新设计 GPU 计算流水线FlashAttention-4 还重新设计了 GPU 的执行流水线。传统流程是矩阵乘法 → softmax → 输出而 FlashAttention-4 采用异步 pipeline当一部分数据在做矩阵乘法另一部分同时做 softmax做数据加载这种方式类似CPU 的超流水线执行GPU 利用率大幅提升。3 减少 Softmax 重缩放FlashAttention 使用一种叫online softmax的算法。它会频繁执行一个操作rescaleFlashAttention-4 的观察是其实只有在最大值变化时才需要 rescale。于是他们加了一层判断如果变化不大直接跳过 rescale。结果Softmax 的计算量再次减少。4 利用 GPU 新的 Tensor MemoryBlackwell GPU 引入了一个新的内存层Tensor Memory每个 SM 大约256KBFlashAttention-4 利用这个内存存储中间 Attention 结果。好处是减少 shared memory 访问降低寄存器压力支持更大的 tile这进一步提升了性能。五、性能提升有多大论文在NVIDIA B200GPU 上进行了测试。结果非常惊人。相比其他实现FlashAttention-4比cuDNN Attention 快 1.1 – 1.3 倍比Triton 实现 快 2 – 2.7 倍峰值算力1613 TFLOPS大约达到 GPU 理论算力的71%对于 GPU kernel 来说这是一个非常高的利用率。六、一个很多人没注意的改变FlashAttention-4 还有一个有意思的变化。它不再使用复杂的 C 模板。而是基于CuTe-DSL一个 Python DSL。优势非常明显Python 写 GPU kernel自动生成 PTXJIT 编译最关键的是编译时间从 55 秒降到 2.5 秒。研究人员可以更快测试新想法。七、这篇论文真正重要的地方FlashAttention-4 的意义其实不仅仅是一个优化。它揭示了一个趋势未来 AI 系统的性能瓶颈可能不再是矩阵计算。而是内存访问非线性函数调度也就是说AI 进入了一个新的阶段算力不是唯一瓶颈。如何设计更贴近硬件的算法会变得越来越重要。如果说最早的 FlashAttention 解决的是“Attention 太占显存”那么 FlashAttention-4 解决的是“GPU 太快了其它部分反而跟不上。”当 AI 硬件继续狂飙时这类算法 硬件协同设计的优化很可能会越来越重要。也许未来的大模型性能突破并不来自新的模型结构。而来自这些隐藏在底层的系统工程创新。更多transformerVITswin tranformer 参考头条号人工智能研究所 v号人工智能研究Suo, 启示AI科技动画详解transformer 在线视频教程
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2420931.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!