Skip to content

FlashAttention-2

FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

标准注意力的计算写成

S=QKT,P=softmax(S),O=PV

如果序列长度是 N,头维度是 d,那么分数矩阵 S 和概率矩阵 P 都是 N×N。这带来两个直接后果:

第一,计算量会随着序列长度平方增长。

第二,更麻烦的是中间张量很大,GPU 会频繁在高带宽显存(HBM)和片上存储之间搬运数据。很多时候,标准注意力慢,不只是因为算得多,而是因为“搬得太多”。

FlashAttention 的第一代工作正是围绕这一点展开:它不改变注意力公式本身,而是通过分块、在线 softmax 和重计算,把中间矩阵不再完整写回 HBM,在保证结果完全精确的前提下,把额外显存从二次规模降到线性规模。

FlashAttention-2 继承了这个方向,但把重点放在 GPU 并行和工作划分上,目标是让注意力更接近矩阵乘法内核的硬件效率。

为什么已经有 FlashAttention 还要做 FlashAttention-2

第一代 FlashAttention 已经带来了非常明显的收益:它是精确注意力,不做近似,额外显存随序列长度线性增长,而且在长序列上通常比常规实现更快。

但作者在 FlashAttention-2 里指出,第一代实现离 GPU 理论峰值还有明显距离。论文给出的分析是:FlashAttention 在 A100 上通常只达到大约 25% 到 40% 的理论最大 FLOPs/s,落后于优化得很好的矩阵乘法内核。

问题不再是“有没有省掉大矩阵”,而是更细的硬件层面问题:线程块之间的工作划分不够理想,导致有时占用率不高;同一个线程块内部,不同 warp 之间又会产生不必要的共享内存读写和同步。

FlashAttention-2 的工作,就是继续保持第一代的精确性和低显存占用,同时把这些硬件侧瓶颈再往下压。

FlashAttention

FlashAttention 的核心做法是把 QKV 按块搬到片上存储里,在片上完成一小块一小块的注意力计算,再把输出块更新回去。这样做的时候,不能把完整的 S=QKTP=softmax(S) 全部显式存下来,因此需要用在线 softmax 来逐块维护每一行的归一化统计量。

如果只看单个输出块,FlashAttention 的思路可以概括成:

  1. 固定一块查询 Qi
  2. 依次读取键值块 Kj,Vj
  3. 对每个块计算局部分数 Si(j)=QiKjT
  4. 用在线 softmax 更新该查询块对应的行最大值与归一化和
  5. 直接把这一块对输出 Oi 的贡献累计进去

这样,算法最终得到的仍然是

O=softmax(QKT)V

只是计算顺序被重新安排了。

FlashAttention-2 没有推翻这一套,而是在这个框架之上进一步优化。

减少非矩阵乘法操作

在现代 GPU 上,矩阵乘法通常由专门的计算单元加速,例如 Tensor Cores。矩阵乘法的吞吐非常高,而一些逐元素运算、缩放、归一化、额外的共享内存访问,则相对昂贵。FlashAttention-2 的一个出发点是:

既然矩阵乘法很快,就应该尽量把时间花在矩阵乘法上,而不是花在周边的小操作上。

论文给出的例子是,A100 上 FP16/BF16 矩阵乘法的理论吞吐可以到 312 TFLOPs/s,而普通 FP32 非矩阵乘法只有 19.5 TFLOPs/s。换句话说,非矩阵乘法的每一步都更“贵”。

FlashAttention-2 在前向里做了两处调整。

延后对输出的缩放

在在线 softmax 的块累积过程中,FlashAttention 会反复对输出做缩放。FlashAttention-2 改成维护一个未归一化的中间输出 O~,把归一化延后到循环末尾再做。

于是块级更新可以写成

O~i(j)=diag(emi(j1)mi(j))O~i(j1)+P~i(j)Vj

其中

P~i(j)=exp(Si(j)mi(j))

最后再统一做一次

Oi=diag(i(last))1O~i(last)

这样能减少循环体中的额外缩放开销。

只保存 logsumexp

为了反向传播,FlashAttention 会保留更多中间统计量。FlashAttention-2 指出,前向阶段没有必要同时保存行最大值 m 和指数和 ;保存

L=m+log

也就是 logsumexp,就足够支撑后向所需的信息。这减少了额外读写和存储压力。

这两处改动不改变输出,只是在算法层面减少了非矩阵乘法与额外的内存操作。

沿序列维增加并行度

FlashAttention 并行主要依赖 batch 维和头数维。如果 batch 很大、头很多,这样已经能把 GPU 填得比较满。但在很多长序列训练场景里,情况恰好相反:

  • 序列很长
  • batch 往往变小
  • 头数也未必多到足以把所有 SM 都占满

于是 GPU 会出现“还有空闲资源,但并没有足够多的独立工作可以派发”的问题,也就是占用率不够高。

FlashAttention-2 解决这个问题的方法,是把并行范围进一步扩展到序列维。

前向阶段的做法

前向里可以把不同的查询块分给不同线程块。因为这些查询块之间在当前阶段互不依赖,所以它们可以并行执行。论文把它描述为:沿序列长度维并行前向计算,让不同线程块分别负责注意力矩阵的不同“行块”。

如果写成块索引,前向中每个线程块负责一个或多个 Qi,然后遍历所有 Kj,Vj 来完成对应输出块 Oi 的更新。

后向阶段的做法

后向的依赖关系更复杂,但仍然可以沿序列维继续加并行。论文中给出的做法是:在后向里让线程块分担注意力矩阵的不同“列块”,并用原子加法来累积对 dQ 的更新。

这类改动的效果,在长序列、小 batch、小头数的情况下尤其明显,因为它直接增加了可以同时调度到 GPU 上的工作数。

在线程块内部重新划分 warp 的工作

除了线程块之间如何划分,线程块内部的 warp 该怎么分工,也会显著影响性能。

FlashAttention 在前向里采用的是一种可以称为“按键值切片”的方式:

  • 保持 Q 对所有 warp 可见
  • KV 切给不同 warp
  • 各个 warp 分别算出部分结果后,再通过共享内存汇总

这种做法的问题是:每个 warp 的中间结果都要写到共享内存,再同步,再做加总,通信成本不低。

FlashAttention-2 把分工方式调了过来:

  • 保持 KV 对所有 warp 可见
  • Q 在 warp 之间切分
  • 每个 warp 只负责自己那部分查询对应的输出

这样每个 warp 算完自己的那部分后,基本不需要再和别的 warp 进行结果归约。共享内存读写和同步次数明显减少。

FlashAttention 更像这样

  • 所有人一起看同一块 Q
  • 不同 warp 各自拿一部分 K,V
  • 最后需要把大家算出来的部分输出合并

FlashAttention-2 更像这样

  • 所有人一起看同一块 K,V
  • 不同 warp 各自负责不同的 Q
  • 每个 warp 直接产出自己的输出片段

第二种方式省掉了很多线程块内部的通信,这就是论文里强调的“减少共享内存读写”的来源。

FlashAttention-2 到底快了多少

论文报告,FlashAttention-2 相比 FlashAttention 大约有 2 倍左右的加速,在 A100 上前向吞吐可以达到理论峰值的 50% 到 73%,后向可达到理论峰值的 63%;用于端到端 GPT 风格训练时,单张 A100 可达到 225 TFLOPs/s 左右的训练速度。

FlashAttention-2 没有做什么

FlashAttention-2 没有把注意力改成线性复杂度方法,也没有引入稀疏模式、低秩近似或者核函数近似。它不是 Longformer、Performer、Linformer 那一类“换注意力定义”的方法。

它做的是:

  • 保持注意力结果完全不变
  • 重新组织计算
  • 让 GPU 更少搬数据
  • 让更多线程块和 warp 同时高效工作

因此它特别适合那些仍然需要标准注意力质量,但又希望把训练和推理做得更快的场景。

总结

FlashAttention-2 是在 FlashAttention 基础上的一次硬件效率升级。它没有改变注意力公式,也没有牺牲精度,而是围绕 GPU 执行效率做了三类优化:减少非矩阵乘法操作、沿序列维增加线程块并行、在线程块内部重新划分 warp 的工作,从而减少共享内存通信。结果是,在保持精确注意力和线性额外显存的前提下,它比第一代更接近高性能矩阵乘法内核的效率。