FlashAttention
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
FlashAttention 是一种 精确注意力算法。它没有改变自注意力的数学定义,也不是近似注意力;它做的事情是:
重新安排注意力的计算顺序,让注意力在 GPU 上更符合内存层次结构,从而显著减少显存读写,提升速度并降低显存占用。
FlashAttention = 不把完整的
注意力矩阵落到慢显存里,而是分块计算、在线归一化、边算边聚合。
它最早由 Tri Dao 等人在 2022 年提出。论文强调,它的关键不是单纯减少浮点运算量,而是做 IO-aware 的注意力,也就是显式考虑 GPU 不同存储层之间的数据搬运成本。原论文还强调,FlashAttention 是 exact attention,不是近似方法。
标准注意力
给定查询矩阵
如果把中间量写出来,通常会先计算:
然后做行方向 softmax:
最后输出:
从数学上看,这没有问题;但从硬件实现上看,问题很大:
是一个 矩阵; 也是一个 矩阵; - 如果序列很长,这两个中间量会非常大;
- 标准实现往往会把这些中间结果写入高带宽显存(HBM),再从 HBM 读回来继续算。
也就是说,标准注意力的瓶颈不只是算力,还包括:
为了算一次注意力,要反复把很大的中间矩阵在慢显存和计算单元之间来回搬运。
FlashAttention 解决的真正问题
很多人第一次接触 FlashAttention 时,会以为它在解决“自注意力计算量是二次复杂度”这个问题。这个理解只对了一半。
更准确地说,FlashAttention 解决的是:
标准注意力在 GPU 上会产生大量不必要的显存读写。
论文指出,很多近似注意力方法主要在减少 FLOPs,但在真实 GPU 上,壁钟时间不一定随 FLOPs 成比例下降;真正经常卡住的是 IO,也就是 HBM 和片上 SRAM 之间的数据搬运。FlashAttention 的核心就是让注意力计算更符合 GPU 的内存层次。
一个很容易弄错的点
FlashAttention 没有把精确注意力的算术复杂度从
它仍然是精确注意力,主要矩阵乘法量级仍然是
它变快的关键,是 更少的 HBM 访问、更少的中间矩阵物化、更强的核函数融合。
核心思想一:分块(tiling)
FlashAttention 的第一步是把
假设把:
按行切成若干块 ; 按行切成若干块 。
那么对某个块对
然后立刻把这个小块对最终输出的贡献累加进去,而不是把完整的
TIP
可以把 GPU 想成:
HBM:大仓库,容量大,但取东西慢;
SRAM:桌面,容量小,但拿东西很快;
Tensor Core / CUDA Core:真正干活的人。
标准 attention 经常是:
算出大矩阵,搬到仓库;再从仓库搬回来 softmax;再搬回仓库;再搬回来乘 V。
FlashAttention 的思路是:
别把大矩阵搬来搬去了,小块小块放在桌面上算完就走。
核心思想二:在线 softmax(online softmax)
分块以后,马上会遇到一个难点:
softmax 不是线性的。
如果整行分数被拆成很多块,我们不能简单地对每个块分别做 softmax,再把结果直接拼起来。因为真正的 softmax 归一化分母是整行一起算的:
那 FlashAttention 为什么还能分块而且保持 完全精确?
关键就在于:
softmax 的“最大值”和“指数和”可以做在线聚合。
先看单行情形
假设某一行分数被分成两段:
对第一段,定义:
对第二段,定义:
那么整个行的最大值是:
整个行在以
这说明:
- 行最大值可以逐块更新;
- softmax 的归一化分母也可以逐块更新;
- 因此,softmax 可以在块级别上精确累计,而不需要一次看到整行所有元素。
这就是 FlashAttention 中“在线 softmax”的核心。
前向过程
先只看一个查询块
FlashAttention 会为这个块维护三类行级统计量:
- 当前行最大值
; - 当前归一化因子
; - 当前输出累积量
。
初始化:
然后依次遍历每个键值块
第一步:计算当前块的分数
第二步:求这个块自己的行最大值和指数和
第三步:把当前块和历史统计量合并
新的行最大值:
新的归一化统计量:
新的未归一化输出累积量:
其中
全部块处理完以后,再做一次行归一化:
其中
本质上是在做一件事:
始终维护“到当前为止整行分数的最大值”; 始终维护“以当前全局最大值为基准的指数和”; 始终维护“softmax 分子部分与 的加权和”。
因此最后得到的
这就是 FlashAttention 的关键:
它是精确重排,不是近似替代。
为什么它更省显存
标准注意力最费显存的地方,是中间量:
,大小是 ; ,大小也是 。
而 FlashAttention 的做法是:
- 不把完整
存到 HBM; - 不把完整
存到 HBM; - 只在片上 SRAM 中暂时处理当前块;
- 只保存必要的小统计量和最终输出。
因此它的额外内存不再随
FlashAttention 不是“少算了很多”,而是“少存了很多、少搬了很多”。
为什么它会更快
如果只看 FLOPs,FlashAttention 并没有把精确注意力从二次变成线性。
但在 GPU 上,注意力往往不是纯算力瓶颈,而是 内存带宽瓶颈。原论文把核心问题表述成:标准注意力会把大的
因此,FlashAttention 的速度提升主要来自三件事:
- 不物化完整注意力矩阵;
- 尽量让计算在片上 SRAM 中完成;
- 把原本分散的 matmul、mask、softmax、dropout、再 matmul 融成更少的核函数。
反向传播为什么也能省内存
前向时不保存完整的
FlashAttention 的答案是:
重算(recomputation)。
也就是说,前向传播不把巨大的中间矩阵存到 HBM 里;反向传播时,需要某个块的中间量,就把这个块重新在 SRAM 中算一遍。
乍看这像是在“多做事”,但在现代 GPU 上,这通常反而更划算。原因是:
- 重新做一些局部计算,会增加少量 FLOPs;
- 但避免了大量慢速 HBM 读写;
- 而注意力常常本来就是内存受限,不是算力受限。
所以论文强调,即使由于重算带来额外 FLOPs,FlashAttention 仍然会更快。
FlashAttention 不是新的注意力公式
FlashAttention 不是像多头注意力、分组查询注意力、线性注意力那样,提出了一种新的注意力定义。
它的输出仍然是:
也就是说:
- 从数学上,它和标准缩放点积注意力完全一样;
- 从实现上,它把计算顺序和内存访问方式重新设计了一遍。
所以更准确地说,FlashAttention 是:
一种面向 GPU 内存层次优化的精确注意力实现算法。
它和长上下文有什么关系
FlashAttention 之所以对长上下文尤其重要,是因为序列越长,标准注意力的
FlashAttention 虽然没有改变精确注意力对序列长度的二次算术依赖,但它把“能不能在显存里撑住”和“真实壁钟时间有多慢”这两件事大幅改善了。
因此它带来的实际效果通常是:
- 在同样硬件下支持更长序列;
- 或者在同样序列长度下更快;
- 或者在同样显存预算下训练更大 batch;
- 或者在推理时减少注意力成为瓶颈的程度。
这也是为什么后来的长上下文模型和高效推理库几乎都会优先集成 FlashAttention 类实现。
它和近似注意力的区别
在 FlashAttention 之前,很多“高效注意力”方法是通过近似来换速度,比如:
- 只看局部窗口;
- 做低秩近似;
- 稀疏化注意力图;
- 把 softmax attention 改写成线性注意力。
这些方法通常会在数学上改动注意力本身。
而 FlashAttention 不一样。它不近似,不删边,不改核函数,也不换成别的公式。它只是说:
同样的注意力公式,不能再按低效的方式实现了。
所以它的优势是:
- 不牺牲精度;
- 不改模型语义;
- 常常可以直接替换现有精确注意力实现。
总结
FlashAttention = 用分块、在线 softmax 和重算,避免在 HBM 中物化完整注意力矩阵的精确注意力算法。
FlashAttention 的核心价值,不是提出了一个新的注意力定义,而是把标准精确注意力重新设计成了更符合 GPU 内存层次的实现。
它通过分块计算、在线 softmax、核函数融合和反向传播中的重算,避免了把大的
原始 FlashAttention 的重点是 IO-aware 的精确注意力;FlashAttention-2 则在此基础上进一步优化并行划分和线程块内部工作分配,让实现更接近高效矩阵乘法的硬件利用率。
Online Softmax 的具体例子
假设某一行 attention score 是:
标准 softmax 是:
为了数值稳定,通常会减去最大值
也就是:
现在 FlashAttention 不想一次看完整行,而是分两块看:
第一块:
这一块最大值是:
这一块的指数和是:
第二块:
这一块最大值是:
这一块的指数和是:
现在要把两块合并。整行最大值是:
但是第一块之前是按最大值
代入:
也就是:
如果直接用全局最大值
可以看到,分块合并和一次性计算得到的分母完全一样。
如果再带上
标准 attention 输出是:
FlashAttention 会维护一个未归一化输出:
第一块按
第二块按
合并时同样要缩放:
也就是:
最后除以分母:
所以 online softmax 真正在维护三个东西:
:当前见过的最大值; :以当前最大值为基准的指数和,也就是 softmax 分母; :以当前最大值为基准的加权值累积。
每来一个新块,就更新
这就是 FlashAttention 能够分块计算、但仍然保持精确的原因。