Skip to content

FlashAttention-3

快速且精确的异步低精度注意力

FlashAttention-3FlashAttention 系列在 Hopper 架构 GPU 上的一次重写。它没有改掉注意力的数学定义,依然计算标准的精确注意力:

Attention(Q,K,V)=softmax(QKd)V

真正变化的是:

  • 数据怎样分块搬运;
  • 矩阵乘和 softmax 怎样穿插执行;
  • H100 / H800 这类 Hopper GPU 上,怎样把新的硬件能力真正吃满。

它在解决什么问题

标准注意力在实现上会遇到两个层面的成本。

第一层是中间矩阵太大。长度为 N 的序列会产生一个 N×N 的注意力分数矩阵,如果把它完整写回显存,再读出来做 softmax、再乘 V,显存读写会很重。

第二层是 GPU 真正慢的地方不全在矩阵乘。不仅仅是 QKPV 这两个矩阵乘,在现代 GPU 上,像 softmax 里的 exp、归一化、缩放、mask 这类逐元素操作,吞吐远低于 Tensor Core 上的矩阵乘。如果矩阵乘很快,而这些非矩阵乘操作跟不上,整体速度还是上不去。

FlashAttention 第一代和第二代已经把第一个问题解决了很多:它们通过分块和融合内核,避免把巨大的注意力矩阵完整写回 HBM,只保留每个 block 计算所需的局部信息,因此把额外显存开销从随序列长度平方增长压到线性增长。

FlashAttention-3 继续往前走,针对的是第二个问题:

Hopper 这种新硬件上,怎样让矩阵乘、数据搬运、softmax 这些环节尽量同时发生,而不是排队发生。

为什么要单独做 FlashAttention-3

FlashAttention-2A100 上已经很强,但 Hopper 这一代 GPU 新增了几类能力:

  • WGMMA:更适合 Hopper Tensor Core 的矩阵乘指令;
  • TMATensor Memory Accelerator,能更高效地把数据从全局显存搬到共享内存;
  • FP8:更低精度、更高吞吐的 Tensor Core 路径。

如果内核写法还是沿用旧思路,就很难把这些能力真正用起来。FlashAttention-3 论文指出,FlashAttention-2H100 上只能达到约 35% 的理论峰值利用率,而 FlashAttention-3 面向 Hopper 重新设计后,FP16 可达到最高约 740 TFLOPs/s,约为 75% 的理论峰值利用率;FP8 路径可接近 1.2 PFLOPs/s。官方博客与论文都给出了 1.5x2.0xFP16 提速范围。

先回顾 FlashAttention 家族保留下来的部分

这一代并没有推翻前两代最核心的思想。前两代留下来的骨架仍然成立:

  • 按 block 读取 Q,K,V
  • 在片上内存里完成分数计算、缩放、mask、softmax、乘 V
  • 用在线 softmax 维护每一行的最大值和归一化系数;
  • 不把完整的注意力矩阵 P 写回 HBM。

如果把某一行 softmax 的在线统计量记成:

mi=maxjsijli=jesijmi

那么 block 级 softmax 可以在不断看到新分块时更新 mili,而不必一次把整行所有分数都展开到显存里。FlashAttention-3 建立在这套思路上,只是把 kernel schedule 改得更适合 Hopper

Hopper 上真正吃满性能,难点在哪里

H100 上,矩阵乘的理论吞吐极高,而 softmax 里的指数函数、归一化等操作主要跑在别的执行单元上,吞吐远低于矩阵乘。Tri Dao 的官方博客给了一个很形象的数据:H100 SXM5FP16 矩阵乘峰值大约是 989 TFLOPS,但 special functions 的吞吐只有大约 3.9 TFLOPS。对 head dimension = 128 的注意力来说,虽然指数运算总 FLOPs 远少于矩阵乘,但因为 special function 单元慢得多,它们仍然可能占掉接近一半的时间;到了 FP8,矩阵乘更快,这个矛盾还会更明显。 博客

这也是 FlashAttention-3 的核心出发点:

  • 不是只让 QK 更快;
  • 而是让 QK、softmax、PV、数据搬运尽量重叠执行。

异步搬运与异步计算

Hopper 的一个变化是:数据搬运和矩阵乘不必再像早期写法那样严格串行。

TMA 负责更省资源的数据搬运

TMA 可以把 tile 从全局显存搬到共享内存,并把原本需要许多线程自己做的地址计算、越界处理等工作交给专门单元。这样做有两个直接结果:

  • 共享内存搬运更高效;
  • 寄存器压力下降,可以把 tile 做大,或者给其他计算留下更多寄存器空间。

WGMMA 提供更适合 Hopper 的矩阵乘路径

WGMMAHopper 上面向 Tensor Core 的新矩阵乘指令,比老一代路径更接近硬件峰值。FlashAttention-3 在内核里围绕 WGMMA 来组织 QK^\topPV 的大块矩阵乘。

Warp specialization

FlashAttention-3 不再让所有 warp 做同一种事,而是把工作拆成不同角色:

  • 一部分 warp 更偏向做生产者,负责借助 TMA 搬 tile;
  • 一部分 warp 更偏向做消费者,负责 WGMMA 矩阵乘;
  • softmax 和归一化计算再穿插到这些阶段之间。

这种写法的目标不是“让每一段都最快”,而是“让整个流水更少空转”。

矩阵乘和 softmax 不再排队

为什么不能把 softmax 当成边角料

注意力里至少有两类核心操作:

  • 矩阵乘:QKPV
  • 非矩阵乘:softmax 的缩放、减最大值、指数、求和、归一化。

Hopper 上,矩阵乘跑得极快。如果把整段流程写成:

  1. 先做完整 block 的矩阵乘;
  2. 再做 softmax;
  3. 再继续下一个 block;

那么 Tensor Core 和做 softmax 的执行单元很多时候是轮流忙碌,而不是同时忙碌。

Ping-pong 调度

FlashAttention-3 的一个调度思路是把不同 warpgroup 交替安排。可以把它想成两组工人轮流接棒:

  • 第 1 组在做这一轮的矩阵乘;
  • 第 2 组趁这个空隙处理另一轮的 softmax;
  • 然后再交换角色。

博客把这个过程称为 ping-pong scheduling。它带来的效果是:某一组在等矩阵乘结果时,另一组并不是闲着,而是在推进 softmax 和归一化。官方博客给出的数字是,FP16 forwardhead dim 128seq len 8K 下,采用这类调度可以把吞吐从约 570 TFLOPS 提到约 620 TFLOPS博客

同一个 warpgroup 内部也能再叠一层

除了不同 warpgroup 之间交错,FlashAttention-3 还在同一个 warpgroup 内部继续做细粒度穿插,让 softmax 的一部分计算和矩阵乘继续重叠。博客里给出的描述是,这样可以再把 FP16 forward 从大约 620 TFLOPS 提到 640–660 TFLOPS,代价是寄存器压力进一步上升。 博客

这里可以看到这一代设计的味道:

  • FlashAttention-1 更像先把 IO 问题打掉;
  • FlashAttention-2 更像做更好的并行划分;
  • FlashAttention-3 则明显开始进入“内核流水和硬件执行单元怎么彼此遮蔽”的层面。

FP8 路径与低精度误差

FP8 很吸引人,因为它能进一步提高 Tensor Core 吞吐,也常常意味着更省显存和带宽。但问题也很直接:位宽更低,量化误差更大,特别是激活里有离群值时更明显。

为什么离群值会让 FP8 更难用

如果某些通道值特别大,而绝大多数值比较小,那么量化时通常会为了容纳大值而牺牲小值的分辨率。注意力里的 QK 如果带有这类 outlier,计算 QK 时误差会被进一步放大。

Incoherent processing 在做什么

FlashAttention-3 的做法之一,是在量化前先对 QK 做一种“打散离群值”的变换,让大值不要总集中在少数坐标上。论文和博客把它称为 incoherent processing,实现上采用带随机符号的 Hadamard transform论文 博客

如果记原向量为 x,那么可以写成:

x~=HDx

其中:

  • D 是随机符号对角矩阵;
  • H 是 Hadamard 变换。

这样做不会改变“信息总量”,但会把极端值更均匀地摊开,使低精度量化更稳定。

为什么 Hadamard 适合放到内核里

Hadamard 变换可以用 O(dlogd) 时间完成,而且本身更偏内存带宽受限。博客提到,它还能和 RoPE 这类本就带有带宽压力的前处理阶段融合,因此额外成本相对可控。 博客

FP8 精度结果

论文报告说,带 FP8FlashAttention-3H100 上可接近 1.2 PFLOPs/s,同时相对某个基线 FP8 attention,数值误差降低约 2.6x。这说明它追求的不是“为了更快随便量化”,而是“在 Hopper 的 FP8 路径上把速度和误差一起管住”。 论文

数学对象没变,执行图变了

从模型作者的视角看,FlashAttention-3 仍然在做同一个注意力:

O=softmax(QKd)V

从 kernel 设计者的视角看,它已经不是一段“先算完这个,再算下一个”的顺序程序,而更像一个流水系统:

  • 下一块数据正在搬;
  • 当前块的矩阵乘正在 Tensor Core 上跑;
  • 上一块的 softmax 和归一化正在别的执行单元上跑;
  • 输出块正在被写回。

所以 FlashAttention-3 的价值不在“提出新公式”,而在“让同一个公式更像一条流水线”。

适用范围和工程限制

FlashAttention-3 并不是对所有 GPU 都一视同仁。当前官方代码把它明确放在 hopper 目录下,定位是面向 H100 / H800Hopper GPU。GitHub 仓库里,FlashAttention-3 仍被标为一个 beta release;当前公开说明写的是:

  • FP16 / BF16 支持 forward 和 backward;
  • FP8 当前公开的是 forward;
  • 需要 H100 / H800,并要求 CUDA >= 12.3,仓库还建议 CUDA 12.8 以获得更好性能。

这也意味着:

  • 如果你的机器是 A100,重点还是 FlashAttention-2
  • 如果你的机器是 H100FlashAttention-3 才更贴合硬件。

总结

FlashAttention-3 仍然计算标准的精确注意力,但它把 attention kernel 重新组织成了更贴合 Hopper GPU 的流水结构。

它继承了前两代的分块与在线 softmax 思路,继续避免把大注意力矩阵写回 HBM;同时利用 WGMMATMA、warp specialization、交错式矩阵乘与 softmax 调度,以及面向 FP8 的低误差处理,把 H100 上的实际利用率和吞吐进一步抬高。

论文报告它在 H100 上相对 FlashAttention-2 带来 1.5x–2.0xFP16 提速,FP16 可达 740 TFLOPs/sFP8 接近 1.2 PFLOPs/s,并把 FP8 数值误差压到基线的约 1/2.6