Skip to content

FlashAttention

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Dao-AILab/flash-attention

FlashAttention 是一种 精确注意力算法。它没有改变自注意力的数学定义,也不是近似注意力;它做的事情是:

重新安排注意力的计算顺序,让注意力在 GPU 上更符合内存层次结构,从而显著减少显存读写,提升速度并降低显存占用。

FlashAttention = 不把完整的 N×N 注意力矩阵落到慢显存里,而是分块计算、在线归一化、边算边聚合。

它最早由 Tri Dao 等人在 2022 年提出。论文强调,它的关键不是单纯减少浮点运算量,而是做 IO-aware 的注意力,也就是显式考虑 GPU 不同存储层之间的数据搬运成本。原论文还强调,FlashAttention 是 exact attention,不是近似方法。

标准注意力

给定查询矩阵 QRN×d、键矩阵 KRN×d、值矩阵 VRN×d,标准缩放点积注意力写成:

Attn(Q,K,V)=softmax(QKd)V

如果把中间量写出来,通常会先计算:

S=QKdRN×N

然后做行方向 softmax:

P=softmax(S)RN×N

最后输出:

O=PVRN×d

从数学上看,这没有问题;但从硬件实现上看,问题很大:

  • S 是一个 N×N 矩阵;
  • P 也是一个 N×N 矩阵;
  • 如果序列很长,这两个中间量会非常大;
  • 标准实现往往会把这些中间结果写入高带宽显存(HBM),再从 HBM 读回来继续算。

也就是说,标准注意力的瓶颈不只是算力,还包括:

为了算一次注意力,要反复把很大的中间矩阵在慢显存和计算单元之间来回搬运。

FlashAttention 解决的真正问题

很多人第一次接触 FlashAttention 时,会以为它在解决“自注意力计算量是二次复杂度”这个问题。这个理解只对了一半。

更准确地说,FlashAttention 解决的是:

标准注意力在 GPU 上会产生大量不必要的显存读写。

论文指出,很多近似注意力方法主要在减少 FLOPs,但在真实 GPU 上,壁钟时间不一定随 FLOPs 成比例下降;真正经常卡住的是 IO,也就是 HBM 和片上 SRAM 之间的数据搬运。FlashAttention 的核心就是让注意力计算更符合 GPU 的内存层次。

一个很容易弄错的点

FlashAttention 没有把精确注意力的算术复杂度从 O(N2) 变成线性。

它仍然是精确注意力,主要矩阵乘法量级仍然是 O(N2d)

它变快的关键,是 更少的 HBM 访问、更少的中间矩阵物化、更强的核函数融合

核心思想一:分块(tiling)

FlashAttention 的第一步是把 QKV 切成小块,而不是一次性把整个 S=QK/d 全算出来。

假设把:

  • Q 按行切成若干块 QiRBr×d
  • K,V 按行切成若干块 Kj,VjRBc×d

那么对某个块对 (i,j),只需要在片上 SRAM 中计算:

Sij=QiKjdRBr×Bc

然后立刻把这个小块对最终输出的贡献累加进去,而不是把完整的 N×N 矩阵存下来。

TIP

可以把 GPU 想成:

HBM:大仓库,容量大,但取东西慢;

SRAM:桌面,容量小,但拿东西很快;

Tensor Core / CUDA Core:真正干活的人。

标准 attention 经常是:

算出大矩阵,搬到仓库;再从仓库搬回来 softmax;再搬回仓库;再搬回来乘 V。

FlashAttention 的思路是:

别把大矩阵搬来搬去了,小块小块放在桌面上算完就走。

核心思想二:在线 softmax(online softmax)

分块以后,马上会遇到一个难点:

softmax 不是线性的。

如果整行分数被拆成很多块,我们不能简单地对每个块分别做 softmax,再把结果直接拼起来。因为真正的 softmax 归一化分母是整行一起算的:

softmax(sk)=eskjesj

那 FlashAttention 为什么还能分块而且保持 完全精确

关键就在于:

softmax 的“最大值”和“指数和”可以做在线聚合。

先看单行情形

假设某一行分数被分成两段:

s=[s(1),s(2)]

对第一段,定义:

m(1)=max(s(1))(1)=kesk(1)m(1)

对第二段,定义:

m(2)=max(s(2))(2)=kesk(2)m(2)

那么整个行的最大值是:

m=max(m(1),m(2))

整个行在以 m 为基准时的指数和就是:

=em(1)m(1)+em(2)m(2)

这说明:

  • 行最大值可以逐块更新;
  • softmax 的归一化分母也可以逐块更新;
  • 因此,softmax 可以在块级别上精确累计,而不需要一次看到整行所有元素。

这就是 FlashAttention 中“在线 softmax”的核心。

前向过程

先只看一个查询块 Qi

FlashAttention 会为这个块维护三类行级统计量:

  • 当前行最大值 miRBr
  • 当前归一化因子 iRBr
  • 当前输出累积量 O~iRBr×d

初始化:

mi=,i=0,O~i=0

然后依次遍历每个键值块 (Kj,Vj)

第一步:计算当前块的分数

Sij=QiKjd

第二步:求这个块自己的行最大值和指数和

m~ij=rowmax(Sij)P~ij=eSijm~ij~ij=rowsum(P~ij)

第三步:把当前块和历史统计量合并

新的行最大值:

minew=max(mi,m~ij)

新的归一化统计量:

inew=emiminewi+em~ijminew~ij

新的未归一化输出累积量:

O~inew=emiminewO~i+em~ijminewP~ijVj

其中 表示按行广播的逐元素乘法。

全部块处理完以后,再做一次行归一化:

Oi=O~ii

其中 表示按行除法。

本质上是在做一件事:

  • mi 始终维护“到当前为止整行分数的最大值”;
  • i 始终维护“以当前全局最大值为基准的指数和”;
  • O~i 始终维护“softmax 分子部分与 V 的加权和”。

因此最后得到的 Oi,和一次性算完整 softmax 再乘 V 的结果是一样的。

这就是 FlashAttention 的关键:

它是精确重排,不是近似替代。

为什么它更省显存

标准注意力最费显存的地方,是中间量:

  • S=QK,大小是 N×N
  • P=softmax(S),大小也是 N×N

而 FlashAttention 的做法是:

  • 不把完整 S 存到 HBM;
  • 不把完整 P 存到 HBM;
  • 只在片上 SRAM 中暂时处理当前块;
  • 只保存必要的小统计量和最终输出。

因此它的额外内存不再随 N2 增长,而可以做到对序列长度 N 线性增长。论文在定理 1 中明确指出,FlashAttention 的额外内存是 O(N),同时仍保持与标准精确注意力相同量级的计算复杂度 O(N2d)

FlashAttention 不是“少算了很多”,而是“少存了很多、少搬了很多”。

为什么它会更快

如果只看 FLOPs,FlashAttention 并没有把精确注意力从二次变成线性。

但在 GPU 上,注意力往往不是纯算力瓶颈,而是 内存带宽瓶颈。原论文把核心问题表述成:标准注意力会把大的 N×N 中间矩阵反复读写到 HBM,而 FlashAttention 通过分块和核函数融合显著减少了这种 HBM 访问。论文还分析了 HBM 访问复杂度:标准注意力需要 Θ(Nd+N2) 级别的 HBM 访问,而 FlashAttention 降为与 SRAM 大小 M 有关的更低量级。

因此,FlashAttention 的速度提升主要来自三件事:

  • 不物化完整注意力矩阵;
  • 尽量让计算在片上 SRAM 中完成;
  • 把原本分散的 matmul、mask、softmax、dropout、再 matmul 融成更少的核函数。

反向传播为什么也能省内存

前向时不保存完整的 SP,反向传播怎么办?

FlashAttention 的答案是:

重算(recomputation)

也就是说,前向传播不把巨大的中间矩阵存到 HBM 里;反向传播时,需要某个块的中间量,就把这个块重新在 SRAM 中算一遍。

乍看这像是在“多做事”,但在现代 GPU 上,这通常反而更划算。原因是:

  • 重新做一些局部计算,会增加少量 FLOPs;
  • 但避免了大量慢速 HBM 读写;
  • 而注意力常常本来就是内存受限,不是算力受限。

所以论文强调,即使由于重算带来额外 FLOPs,FlashAttention 仍然会更快。

FlashAttention 不是新的注意力公式

FlashAttention 不是像多头注意力、分组查询注意力、线性注意力那样,提出了一种新的注意力定义。

它的输出仍然是:

O=softmax(QKd)V

也就是说:

  • 从数学上,它和标准缩放点积注意力完全一样;
  • 从实现上,它把计算顺序和内存访问方式重新设计了一遍。

所以更准确地说,FlashAttention 是:

一种面向 GPU 内存层次优化的精确注意力实现算法。

它和长上下文有什么关系

FlashAttention 之所以对长上下文尤其重要,是因为序列越长,标准注意力的 N×N 中间矩阵越大,HBM 读写和显存占用就越离谱。

FlashAttention 虽然没有改变精确注意力对序列长度的二次算术依赖,但它把“能不能在显存里撑住”和“真实壁钟时间有多慢”这两件事大幅改善了。

因此它带来的实际效果通常是:

  • 在同样硬件下支持更长序列;
  • 或者在同样序列长度下更快;
  • 或者在同样显存预算下训练更大 batch;
  • 或者在推理时减少注意力成为瓶颈的程度。

这也是为什么后来的长上下文模型和高效推理库几乎都会优先集成 FlashAttention 类实现。

它和近似注意力的区别

在 FlashAttention 之前,很多“高效注意力”方法是通过近似来换速度,比如:

  • 只看局部窗口;
  • 做低秩近似;
  • 稀疏化注意力图;
  • 把 softmax attention 改写成线性注意力。

这些方法通常会在数学上改动注意力本身。

而 FlashAttention 不一样。它不近似,不删边,不改核函数,也不换成别的公式。它只是说:

同样的注意力公式,不能再按低效的方式实现了。

所以它的优势是:

  • 不牺牲精度;
  • 不改模型语义;
  • 常常可以直接替换现有精确注意力实现。

总结

FlashAttention = 用分块、在线 softmax 和重算,避免在 HBM 中物化完整注意力矩阵的精确注意力算法。

FlashAttention 的核心价值,不是提出了一个新的注意力定义,而是把标准精确注意力重新设计成了更符合 GPU 内存层次的实现。

它通过分块计算、在线 softmax、核函数融合和反向传播中的重算,避免了把大的 N×N 注意力矩阵写入慢显存,从而把额外内存从随序列长度平方增长压到线性增长,并显著减少显存读写。

原始 FlashAttention 的重点是 IO-aware 的精确注意力;FlashAttention-2 则在此基础上进一步优化并行划分和线程块内部工作分配,让实现更接近高效矩阵乘法的硬件利用率。

Online Softmax 的具体例子

假设某一行 attention score 是:

s=[1,2,3,4]

标准 softmax 是:

softmax(si)=esie1+e2+e3+e4

为了数值稳定,通常会减去最大值 4

softmax(si)=esi4e14+e24+e34+e44

也就是:

esi4e3+e2+e1+1

现在 FlashAttention 不想一次看完整行,而是分两块看:

[1,2][3,4]

第一块:

s(1)=[1,2]

这一块最大值是:

m(1)=2

这一块的指数和是:

(1)=e12+e22=e1+11.3679

第二块:

s(2)=[3,4]

这一块最大值是:

m(2)=4

这一块的指数和是:

(2)=e34+e44=e1+11.3679

现在要把两块合并。整行最大值是:

m=max(2,4)=4

但是第一块之前是按最大值 2 算的,现在全局最大值变成 4,所以第一块的指数和要重新缩放:

=em(1)m(1)+em(2)m(2)

代入:

=e241.3679+e441.3679

也就是:

=e21.3679+11.36791.5530

如果直接用全局最大值 4 算整行分母:

e14+e24+e34+e44=e3+e2+e1+11.5530

可以看到,分块合并和一次性计算得到的分母完全一样。

如果再带上 V,假设对应的值先简单看成标量:

v=[10,20,30,40]

标准 attention 输出是:

O=e1410+e2420+e3430+e4440e14+e24+e34+e44

FlashAttention 会维护一个未归一化输出:

O~=iesimvi

第一块按 m(1)=2

O~(1)=e1210+e222023.679

第二块按 m(2)=4

O~(2)=e3430+e444051.037

合并时同样要缩放:

O~=e24O~(1)+e44O~(2)

也就是:

O~e2×23.679+51.03754.241

最后除以分母:

O=O~=54.2411.55334.93

所以 online softmax 真正在维护三个东西:

  • m:当前见过的最大值;
  • :以当前最大值为基准的指数和,也就是 softmax 分母;
  • O~:以当前最大值为基准的加权值累积。

每来一个新块,就更新 mO~。最后再计算:

O=O~

这就是 FlashAttention 能够分块计算、但仍然保持精确的原因。