FlashAttention-2
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
标准注意力的计算写成
如果序列长度是
第一,计算量会随着序列长度平方增长。
第二,更麻烦的是中间张量很大,GPU 会频繁在高带宽显存(HBM)和片上存储之间搬运数据。很多时候,标准注意力慢,不只是因为算得多,而是因为“搬得太多”。
FlashAttention 的第一代工作正是围绕这一点展开:它不改变注意力公式本身,而是通过分块、在线 softmax 和重计算,把中间矩阵不再完整写回 HBM,在保证结果完全精确的前提下,把额外显存从二次规模降到线性规模。
FlashAttention-2 继承了这个方向,但把重点放在 GPU 并行和工作划分上,目标是让注意力更接近矩阵乘法内核的硬件效率。
为什么已经有 FlashAttention 还要做 FlashAttention-2
第一代 FlashAttention 已经带来了非常明显的收益:它是精确注意力,不做近似,额外显存随序列长度线性增长,而且在长序列上通常比常规实现更快。
但作者在 FlashAttention-2 里指出,第一代实现离 GPU 理论峰值还有明显距离。论文给出的分析是:FlashAttention 在 A100 上通常只达到大约 25% 到 40% 的理论最大 FLOPs/s,落后于优化得很好的矩阵乘法内核。
问题不再是“有没有省掉大矩阵”,而是更细的硬件层面问题:线程块之间的工作划分不够理想,导致有时占用率不高;同一个线程块内部,不同 warp 之间又会产生不必要的共享内存读写和同步。
FlashAttention-2 的工作,就是继续保持第一代的精确性和低显存占用,同时把这些硬件侧瓶颈再往下压。
FlashAttention
FlashAttention 的核心做法是把
如果只看单个输出块,FlashAttention 的思路可以概括成:
- 固定一块查询
- 依次读取键值块
- 对每个块计算局部分数
- 用在线 softmax 更新该查询块对应的行最大值与归一化和
- 直接把这一块对输出
的贡献累计进去
这样,算法最终得到的仍然是
只是计算顺序被重新安排了。
FlashAttention-2 没有推翻这一套,而是在这个框架之上进一步优化。
减少非矩阵乘法操作
在现代 GPU 上,矩阵乘法通常由专门的计算单元加速,例如 Tensor Cores。矩阵乘法的吞吐非常高,而一些逐元素运算、缩放、归一化、额外的共享内存访问,则相对昂贵。FlashAttention-2 的一个出发点是:
既然矩阵乘法很快,就应该尽量把时间花在矩阵乘法上,而不是花在周边的小操作上。
论文给出的例子是,A100 上 FP16/BF16 矩阵乘法的理论吞吐可以到 312 TFLOPs/s,而普通 FP32 非矩阵乘法只有 19.5 TFLOPs/s。换句话说,非矩阵乘法的每一步都更“贵”。
FlashAttention-2 在前向里做了两处调整。
延后对输出的缩放
在在线 softmax 的块累积过程中,FlashAttention 会反复对输出做缩放。FlashAttention-2 改成维护一个未归一化的中间输出
于是块级更新可以写成
其中
最后再统一做一次
这样能减少循环体中的额外缩放开销。
只保存 logsumexp
为了反向传播,FlashAttention 会保留更多中间统计量。FlashAttention-2 指出,前向阶段没有必要同时保存行最大值
也就是 logsumexp,就足够支撑后向所需的信息。这减少了额外读写和存储压力。
这两处改动不改变输出,只是在算法层面减少了非矩阵乘法与额外的内存操作。
沿序列维增加并行度
FlashAttention 并行主要依赖 batch 维和头数维。如果 batch 很大、头很多,这样已经能把 GPU 填得比较满。但在很多长序列训练场景里,情况恰好相反:
- 序列很长
- batch 往往变小
- 头数也未必多到足以把所有 SM 都占满
于是 GPU 会出现“还有空闲资源,但并没有足够多的独立工作可以派发”的问题,也就是占用率不够高。
FlashAttention-2 解决这个问题的方法,是把并行范围进一步扩展到序列维。
前向阶段的做法
前向里可以把不同的查询块分给不同线程块。因为这些查询块之间在当前阶段互不依赖,所以它们可以并行执行。论文把它描述为:沿序列长度维并行前向计算,让不同线程块分别负责注意力矩阵的不同“行块”。
如果写成块索引,前向中每个线程块负责一个或多个
后向阶段的做法
后向的依赖关系更复杂,但仍然可以沿序列维继续加并行。论文中给出的做法是:在后向里让线程块分担注意力矩阵的不同“列块”,并用原子加法来累积对
这类改动的效果,在长序列、小 batch、小头数的情况下尤其明显,因为它直接增加了可以同时调度到 GPU 上的工作数。
在线程块内部重新划分 warp 的工作
除了线程块之间如何划分,线程块内部的 warp 该怎么分工,也会显著影响性能。
FlashAttention 在前向里采用的是一种可以称为“按键值切片”的方式:
- 保持
对所有 warp 可见 - 把
和 切给不同 warp - 各个 warp 分别算出部分结果后,再通过共享内存汇总
这种做法的问题是:每个 warp 的中间结果都要写到共享内存,再同步,再做加总,通信成本不低。
FlashAttention-2 把分工方式调了过来:
- 保持
和 对所有 warp 可见 - 把
在 warp 之间切分 - 每个 warp 只负责自己那部分查询对应的输出
这样每个 warp 算完自己的那部分后,基本不需要再和别的 warp 进行结果归约。共享内存读写和同步次数明显减少。
FlashAttention 更像这样
- 所有人一起看同一块
- 不同 warp 各自拿一部分
- 最后需要把大家算出来的部分输出合并
FlashAttention-2 更像这样
- 所有人一起看同一块
- 不同 warp 各自负责不同的
- 每个 warp 直接产出自己的输出片段
第二种方式省掉了很多线程块内部的通信,这就是论文里强调的“减少共享内存读写”的来源。
FlashAttention-2 到底快了多少
论文报告,FlashAttention-2 相比 FlashAttention 大约有 2 倍左右的加速,在 A100 上前向吞吐可以达到理论峰值的 50% 到 73%,后向可达到理论峰值的 63%;用于端到端 GPT 风格训练时,单张 A100 可达到 225 TFLOPs/s 左右的训练速度。
FlashAttention-2 没有做什么
FlashAttention-2 没有把注意力改成线性复杂度方法,也没有引入稀疏模式、低秩近似或者核函数近似。它不是 Longformer、Performer、Linformer 那一类“换注意力定义”的方法。
它做的是:
- 保持注意力结果完全不变
- 重新组织计算
- 让 GPU 更少搬数据
- 让更多线程块和 warp 同时高效工作
因此它特别适合那些仍然需要标准注意力质量,但又希望把训练和推理做得更快的场景。
总结
FlashAttention-2 是在 FlashAttention 基础上的一次硬件效率升级。它没有改变注意力公式,也没有牺牲精度,而是围绕 GPU 执行效率做了三类优化:减少非矩阵乘法操作、沿序列维增加线程块并行、在线程块内部重新划分 warp 的工作,从而减少共享内存通信。结果是,在保持精确注意力和线性额外显存的前提下,它比第一代更接近高性能矩阵乘法内核的效率。