Skip to content

历史前缀中,到底什么值得缓存?

答案是:

  • 各层历史 token 的 K
  • 各层历史 token 的 V

这就是 KV Cache

但要真正理解它,不能只停在名字上,必须回答三个问题:

  1. 每层到底缓存什么对象
  2. 为什么这些对象能复用
  3. 为什么通常不缓存 Q

缓存什么与为什么只缓存KV

从 attention 公式出发

Attention(Q,K,V)=Softmax(QKTdk)V

对 decoder-only 模型的某一层来说,对于第 t 个 token,可以写成:

qt=htWQkt=htWKvt=htWV

这里:

  • ht 是第 t 个 token 进入该层后的表示
  • qt 是当前 token 发出的查询
  • kt 是当前 token 提供给未来位置检索的键
  • vt 是当前 token 携带的内容值

当模型生成第 t 个 token 时,它需要拿当前 token 的查询去访问历史所有位置:

Attn(qt,K1:t,V1:t)

也就是说,当前 token 会用自己的 qt 去和所有历史位置的 k1,k2,,kt 做匹配,再根据权重去聚合对应的 v1,v2,,vt

这已经暗示了一件事:

历史位置的 K/V 会在未来被反复使用。

KV Cache 到底缓存什么

在最基本的意义上,它保存的是:

  • 每一层
  • 每一个已经出现过的历史 token
  • 对应产生的 KV

也就是说,如果模型有 L 层,那么它通常会维护:

Cache(1),Cache(2),,Cache(L)

其中第 l 层的 cache 可以抽象成:

Cache(l)={K1:t(l),V1:t(l)}

这里:

  • K1:t(l) 表示第 l 层从第 1 个到第 t 个 token 的所有 key
  • V1:t(l) 表示第 l 层从第 1 个到第 t 个 token 的所有 value

所以 KV Cache 并不是只缓存“最后一层”的结果,而是:每一层都有自己的 K/V cache

这是因为:

  • 第 1 层生成的 K/V 只能给第 1 层的 attention 用
  • 第 2 层生成的 K/V 只能给第 2 层用
  • 各层的表示不同,不能混用

为什么历史 K/V 具有复用价值

假设我们已经处理过这些 token:

text
[A, B, C]

当生成下一个 token D 时,在某一层 attention 中:

  • D 会产生自己的 qD
  • 然后 qD 需要去查询历史的 kA,kB,kC
  • 再用权重聚合 vA,vB,vC

如果再生成下一个 token E,会发生什么?

  • E 会产生自己的 qE
  • qE 仍然需要去查询历史的 kA,kB,kC,kD
  • 再聚合 vA,vB,vC,vD

这意味着:

  • A/B/C/DK/V 会被后续 token 一次又一次访问
  • 它们是天然可复用的“历史索引与内容表征”

因此缓存历史 K/V 是非常合理的:

  • 历史 token 已经固定
  • 它们在某层产生的 K/V 也已经固定
  • 后续 token 只需要重复读取,而不需要重新生成

为什么通常不缓存 Q

通常不把 Q 作为长期 cache 的主体,原因是:

  • Q 是“当前步查询历史”的临时量
  • 它不像历史 K/V 那样会被未来很多步重复使用

在 attention 中:

  • Q 的角色是“我要找什么”
  • K 的角色是“我是什么线索”
  • V 的角色是“我携带什么内容”

当新 token 到来时:

  • 它会产生自己的 Q
  • 用这个 Q 去查询历史 K
  • 然后读取对应 V

这里最关键的是:

  • 当前 token 的 Q 只服务于当前这一步的查询
  • 历史 token 的 K/V 会服务于未来许多步

所以二者的复用价值完全不同。

KV Cache 在每一步是如何更新的

假设当前已经处理到第 t1 个 token。

l 层的 cache 中已经有:

K1:t1(l),V1:t1(l)

当第 t 个 token 到来时,会发生:

  1. 第一步

计算当前 token 在第 l 层的:

qt(l),kt(l),vt(l)
  1. 第二步

把新的 kt(l)vt(l) 追加到历史 cache:

K1:t(l)=concat(K1:t1(l),kt(l))V1:t(l)=concat(V1:t1(l),vt(l))
  1. 第三步

使用当前 qt(l) 去查询整个历史 cache:

Attn(qt(l),K1:t(l),V1:t(l))

这就完成了第 t 个 token 在该层的自注意力计算。

Prefill 阶段和 Decode 阶段下,KV Cache 的形成方式

在 prefill 阶段:

  • prompt 中每个 token 都会在每层产生自己的 K/V
  • 整段 prompt 处理完后,cache 里已经装好了 prompt 对应的历史 K/V

也就是说,prefill 完成后,你已经拥有:

  • “prompt 的 cache”

在 decode 阶段:

  • 每生成一个新 token,就只额外生成这个 token 的 K/V
  • 然后把它 append 到已有 cache 后面

所以可以把整个过程理解成:

  • prefill:一次性把初始仓库建起来
  • decode:每来一个新 token,就在仓库后面追加一小段

误区:KV Cache 只缓存 prompt 的 K/V

它会先缓存 prompt 的 K/V,然后在 decode 中继续追加新生成 token 的 K/V。

KV Cache 如何减少重复计算

但 KV Cache 并没有把 attention 变成常数复杂度

即使使用了 KV Cache,decode 时当前 token 仍然要:

  • 用自己的 Q 去和历史所有 K 做匹配
  • 再聚合历史所有 V

这意味着:

  • 当前步的开销仍然会随着历史长度增长

更直白地说:

  • KV Cache 消灭的是“历史 K/V 的重复生成”
  • 它没有消灭“当前 token 访问整段历史”的需求

KV Cache 的代价是什么

它把一部分重复计算,换成了长期保存历史 K/V 的内存开销。

也就是说:

  • 不用 cache:算得慢,但省掉一部分运行时历史存储
  • 用 cache:算得快,但要花显存保存历史 K/V

所以 KV Cache 的本质是一种典型的时间换空间:

  • 用额外内存换更快的 decode

误区:只要模型小,就不需要 KV Cache

只要是 decoder-only 自回归生成,历史重复计算就存在,只是模型越大、上下文越长,这个问题越突出。

误区:KV Cache 只对长 prompt 有用

它对长 prompt 很重要,但对长输出同样重要,因为 decode 会不断增长历史长度。

KV Cache 的张量形状、显存占用与位置处理

单层 KV Cache 的抽象形状

先从最经典的多头注意力开始。

设:

  • batch size 为 B
  • 注意力头数为 H
  • 当前序列长度为 T
  • 每个 head 的维度为 Dh

那么单层的 K/V cache 常常可以抽象为:

Kcache:[B,H,T,Dh]Vcache:[B,H,T,Dh]

这表示:

  • B 个样本
  • 每个样本有 H 个头
  • 每个头存了长度为 T 的历史序列
  • 每个位置对应一个 Dh 维向量

在不同框架里,维度顺序可能有所不同,例如也可能写成:[B,T,H,Dh] 或内部更复杂的布局。

但本质不变:

  • 批次
  • 序列长度
  • 每头维度

这几个维度总会出现。

为什么每一层都要存一份

如果模型有 L 层,那么通常不是只有一个总 cache,而是:

layer1:Kcache1,Vcache1layer2:Kcache2,Vcache2layerL:KcacheL,VcacheL

原因很简单:

  • 第 1 层的 K/V 来自第 1 层输入表示
  • 第 2 层的 K/V 来自第 2 层输入表示
  • 各层的表示空间不同

所以总 cache 占用会和层数线性相关。

Decode 时 cache 是如何增长的

假设某一层当前已有:

Kcache:[B,H,T,Dh]Vcache:[B,H,T,Dh]

现在新生成了 1 个 token。

这个 token 在该层会产生:

knew:[B,H,1,Dh]vnew:[B,H,1,Dh]

于是 cache 更新成:

Kcacheconcat(Kcache,knew,dim=seq)Vcacheconcat(Vcache,vnew,dim=seq)

更新后形状变成:[B,H,T+1,Dh],也就是说,decode 的 cache 是一个典型的“沿序列维追加”的结构。

KV Cache 的显存占用如何估算

对于标准多头注意力,KV Cache 的元素数量大致是:

2×L×B×T×H×Dh

这里前面的 2 表示:

  • 一份 K
  • 一份 V

如果每个元素占 bytes 个字节,那么总字节数近似为:

Bytes2×L×B×T×H×Dh×bytes_per_elem

这个公式直接告诉你:

  • 层数翻倍,cache 近似翻倍
  • batch 翻倍,cache 近似翻倍
  • 上下文长度翻倍,cache 近似翻倍
  • 头数翻倍,cache 近似翻倍

也就是说,KV Cache 是一个典型会随着多种维度同时线性膨胀的对象。

假设模型参数如下:

  • 层数 L=32
  • 注意力头数 H=32
  • 每头维度 Dh=128
  • batch size B=1
  • 序列长度 T=8192
  • 数据类型是 fp16,每个元素 2 字节

那么总占用近似为:

2×32×1×8192×32×128×2

计算后约为:

text
4,294,967,296 bytes ≈ 4 GiB

这只是:

  • 1 个请求
  • 8K 上下文

如果 batch size 变成 4,那么 cache 就会接近 16 GB。

而这还没算:

  • 模型权重
  • 激活
  • 临时工作区
  • 采样和调度的额外开销

所以你可以直观看到:

长上下文和高并发时,KV Cache 很容易成为显存主瓶颈。

为什么 prompt 长度和输出长度都会影响 cache

总序列长度 T 往往由两部分组成:

T=Tprompt+Tgenerated

这意味着:

  • prompt 越长,prefill 结束后初始 cache 就越大
  • 生成越长,decode 阶段 cache 还会继续增长

所以:

  • 长 prompt 会拖高起点
  • 长输出会让 cache 持续膨胀

两者都会推高显存占用。

Batch 对 KV Cache 的影响

在单请求视角下,cache 看起来已经不小。

一旦进入服务端场景,多请求 batch 进来时,情况会更明显。

因为 cache 大小近似正比于 B

也就是说:同样的模型,同样的上下文长度,batch size 越大,cache 越大,这就是为什么很多推理系统在追求吞吐时,总会面临一个权衡:大 batch 有利于吞吐,但大 batch 会迅速吃掉显存

这也解释了后面为什么需要:MQA / GQAKV 量化paged attention

位置编码为什么会和 KV Cache 绑在一起

decode 不是一次性把整段序列重新输入,而是:

  • 历史 token 已经在 cache 中
  • 当前只新增 1 个 token

于是模型必须明确知道:

  • 新 token 处于第几个位置
  • 历史 cache 中的每个 token 对应哪个位置

如果位置处理不一致,注意力就会错位。

绝对位置编码和 decode

对于最朴素的绝对位置编码:

  • 第 1 个 token 用位置 0
  • 第 2 个 token 用位置 1
  • 第 3 个 token 用位置 2

decode 时,当新 token 到来:

  • 它必须拿到新的 position id
  • 不能把它误当成历史某个旧位置

如果 prompt 长度已经是 P,那么第一个生成 token 的位置通常就是 P,第二个生成 token 的位置通常是 P+1,依此类推。这看起来很简单,但在真实实现中,如果 position id 管错了,输出会直接出问题。

RoPE

很多现代 LLM 并不用最原始的绝对位置编码,而更常使用:RoPE,即 Rotary Position Embedding。

从高层直觉上理解,RoPE 的特点是:

  • 把位置信息编码进 query 和 key
  • 让注意力更自然地表达相对位置信息

为什么很多实现要显式维护 position ids

在 decode 阶段,模型不再一次性看到整段序列,而通常只看到:

  • 当前新增 token
  • 以及历史 cache

所以框架必须显式知道:

  • 当前 token 是第几个位置
  • 是否因为滑动窗口或截断发生了位置偏移
  • 多请求混合时各自的位置进度是多少

这就是为什么很多推理代码里你会看到:

  • position_ids
  • cache_position
  • past_length

之类的变量。

它们不是实现噪声,而是在维护:历史 cache 与当前 token 的位置一致性

滑动窗口和 cache 截断

如果模型只支持有限上下文,比如最多 W 个 token,那么当序列继续增长时,可能会发生:

  • 只保留最近 W 个 token 的 cache
  • 更早的历史被丢弃

这就是所谓的滑动窗口思路。

此时实现会更复杂,因为你需要同时处理:

  • cache 截断
  • 注意力可见范围
  • 位置编号或相对位置解释

TIP

KV Cache 并不一定永远无限增长

在有限上下文模型里,它可能会被窗口化管理

KV Cache 的常见优化:MQA、GQA、KV 量化

如果标准多头注意力下的 KV Cache 太大,怎么办?

现代模型和推理系统里,最常见的继续优化方向有三类:

  1. 减少要缓存的 K/V head 数量
  2. 降低 cache 的数值精度
  3. 更聪明地组织和管理 cache

标准 MHA 为什么会有较大的 KV Cache

在标准多头注意力 MHA 中,通常:

  • query 有 Hq 个头
  • key 也有 Hq 个头
  • value 也有 Hq 个头

也就是说,Q/K/V 头数是一一对应的。

这意味着对每个 token、每一层来说,KV cache 大小大致与:

2×Hq×Dh

成正比。

于是上下文一长、层数一多、batch 一上来,cache 就会很大。

MQA:Multi-Query Attention

MQA 的核心思路非常直接:

  • 让 query 仍然保持多头
  • 但让所有 query 头共享同一组 key/value 头

也就是说:

  • Q 还是多头
  • K/V 变成极少头,最极端时就是 1 头

你可以把它理解成:

  • 多个 query 头从不同角度提问
  • 但大家都去查同一份 K/V 索引库

如果标准 MHA 是:

KV per token per layer2×Hq×Dh

那么 MQA 近似变成:

KV per token per layer2×1×Dh

也就是相对标准 MHA,cache 大小大约缩小为 1Hq

如果 Hq=32,那理论上就是非常显著的缩减。

MQA 的代价也很明显:

  • 所有 query 头共享同一份 K/V
  • 表达能力可能不如每个头都有独立 K/V 的标准 MHA

从直觉上看:

  • 你节省了大量 cache
  • 但也减少了 key/value 表示的多样性

所以它是一个典型的:

  • 速度和显存收益很大
  • 表达能力可能有一定折中

GQA:Grouped-Query Attention

GQA 可以理解成:

  • 不是像 MQA 那样让所有 query 头都共享一组 K/V
  • 而是把 query 头分成若干组
  • 每组共享一组 K/V

举例来说:

  • query 头数是 32
  • key/value 头数改成 8

那么每 4 个 query 头共享 1 组 K/V

因此:

  • 它比标准 MHA 的 cache 小
  • 又没有 MQA 那么极端

如果:

  • query 头数为 Hq
  • key/value 头数为 Hkv

那么 GQA 的 cache 大小大致与:

2×Hkv×Dh

成正比。

相对标准 MHA,缩减比例近似为:

HkvHq

例如:

  • Hq=32
  • Hkv=8

那么 cache 大约是原来的 14

很多人喜欢 GQA 就是因为它是一个非常自然的折中:比 MHA 更省内存,比 MQA 更保留表达能力。

机制Query 头数KV 头数cache 大小表达能力直觉
MHA最大最强基线
GQA较少中等较好折中
MQA1最小压缩最激进

KV 量化:另一条思路

除了减少 KV 头数,另一条自然路线是:

不改变 cache 的结构,但降低每个元素的存储精度。

这就是 KV 量化

如果你原来用的是:

  • fp16,每个元素约 2 字节

改成:

  • int8,每个元素约 1 字节

理论上 cache 占用就可以接近减半。

如果再往更低位走,例如更激进的量化,理论上还可以继续下降。

KV 量化 不改变:

  • 层数
  • 序列长度
  • 头数
  • 每头维度

它改变的是:

  • bytes_per_elem

也就是显存公式中的最后一项。

但量化从来不是白捡收益。

精度损失

精度越低,近似误差通常越大。

K/V 又直接参与 attention:

  • key 决定匹配权重
  • value 决定聚合内容

所以量化过猛可能影响生成质量。

反量化开销

很多实现不会让量化后的值永久直接参与所有计算,而会在一定步骤中做反量化或混合精度处理。

这意味着:

  • 你省了内存
  • 但也可能引入额外计算和实现复杂度

不同硬件友好度不同

某些硬件和 kernel 对特定精度格式更友好。

所以 KV 量化 的真实收益往往既取决于算法,也取决于实现栈。

什么时候优先考虑哪种优化

如果主要瓶颈是 cache 过大:GQA / MQAKV 量化

如果你担心表达能力损失太大:GQA

如果你不想动模型结构,只想省显存:KV 量化

现代推理系统里的 KV Cache:Paged Attention、Prompt Cache 与 Continuous Batching

真实服务里经常同时发生这些事:

  • 很多请求一起到来
  • 每个请求 prompt 长度不同
  • 每个请求生成长度不同
  • 有些请求共享相同前缀
  • 请求会动态进出 batch

这时“只知道有 KV Cache”还远远不够。

还必须理解:

  • 如何组织 cache 才不浪费内存
  • 如何复用共享前缀
  • 如何让不同请求一起高效解码

从单请求到多请求,问题为什么会突然变复杂

在单请求视角里,cache 似乎只是:一个随着时间增长的历史 K/V 仓库

但多请求一来,系统立刻会遇到更多约束:

长度不一致

请求 A 的序列长度可能是 300。

请求 B 的序列长度可能是 4000。

请求 C 的序列长度可能是 12000。

它们的 cache 大小完全不同。

生命周期不一致

有的请求刚开始 prefill。

有的请求已经在 decode 第 50 步。

有的请求下一步就会结束。

前缀可能重复

例如多个用户都在问同样的系统提示词或模板化 prompt:

text
You are a helpful assistant...

如果每个请求都重新构建一遍相同前缀的 cache,就会浪费很多计算和内存。

显存是离散分配的

不是说“想用多少就精确切多少”,系统通常要面对:

  • 内存块
  • 分配和回收
  • 碎片化

所以一旦进入服务层,真正的问题就变成:

如何让大量动态请求共享、复用并高效管理 KV Cache。

Paged Attention 想解决什么问题

先从最容易理解的痛点开始:

  • cache 很大
  • 请求长度不一致
  • 动态分配和回收会带来内存碎片

如果用最朴素的连续内存方式存每个请求的整段 cache,可能会出现:

  • 有些请求结束后,留下很多大小不一的空洞
  • 新请求来了,不一定能正好塞进这些空洞
  • 显存明明还有余量,但因为碎片化,实际很难高效利用

Paged Attention 的核心思想可以粗略理解成:

不把每个请求的 KV Cache 视为必须连续存放的一整段,而是切分成固定大小的块来管理。

这个思路和操作系统里“页”的概念很像,所以叫 paged

用“书页”类比 Paged Attention

把一个请求的历史 cache 想成一本书。

要求这本书的所有页必须连续摆在同一块大桌子上。

问题是:

  • 桌子空间会被零碎切开
  • 新书大小不一,很难总找到刚好连续的空位

Paged 方式改成:

  • 把书拆成固定大小的页块
  • 每一页可以放在不同位置
  • 系统维护“逻辑页号 -> 物理页块”的映射

所以 Paged Attention 的重点不是“改变注意力公式”,而是改变 cache 的内存组织方式

Paged Attention 本质上优化的是什么

它主要优化的是:

  • cache 的分配
  • cache 的回收
  • cache 的布局
  • 多请求下的内存利用率

它并不直接改变:

  • decoder-only 自回归生成的定义
  • KV Cache 的基本含义
  • attention 公式本身

所以其本质为:系统层 / 内存管理层优化

Prompt Cache / Prefix Cache 想解决什么问题

现在来看另一个常见优化:

  • prompt cache
  • 或者 prefix cache

它们关注的问题是:

如果多个请求有相同前缀,为什么还要每次都重新 prefill 一遍?

举个例子。

假设很多请求都有同样的系统提示词:

text
You are an expert coding assistant...

如果每次来新请求,都要对这段前缀做完整 prefill,那么会浪费:

  • 计算
  • cache 空间

于是一个自然思路是:

  • 先把这段前缀对应的历史 K/V 保留下来
  • 新请求如果前缀完全一样,就直接复用

Prompt Cache 的收益来自哪里

收益主要来自两个方面。

  1. 减少重复 prefill

相同 prompt 不必每次从头构建一遍历史 cache。

  1. 降低首 token 延迟

因为一部分 prefill 工作已经提前做过了。

所以它最典型的收益场景是:

  • 系统提示固定
  • 模板化 prompt 多
  • 多轮对话中前缀重复度高

Prompt Cache 的限制

它并不是“凡是像一点就能复用”。

最直接的限制是:

  • 前缀必须足够一致

在最保守的实现中,往往要求:

  • token 序列完全一致

如果前缀只差一个 token:

  • 之前的 prefix cache 可能就不能直接原样复用

此外,它还会遇到:

  • cache 淘汰策略
  • 多租户隔离
  • 一致性与命中率权衡

所以 prompt cache 非常有用,但也高度依赖实际业务分布。

Continuous Batching 想解决什么问题

接下来是服务层另一个非常重要的概念:

  • continuous batching

它关注的问题是:

请求不会整齐地同时开始和同时结束,系统怎么让 GPU 一直保持高利用率?

如果只做静态 batch,通常会很浪费:

  • 一批请求一起进来
  • 等到其中长请求拖很久时,短请求已经结束
  • 很多算力会因为等待和填充而被浪费

continuous batching 的核心思路是:

  • 不要求所有请求固定成一批到底
  • 请求可以动态加入和退出 batch
  • 每一轮 decode 都尽量把当前可执行的请求拼起来一起跑

这是一种调度方式上的优化。

为什么 Continuous Batching 和 KV Cache 紧密相关

因为动态 batch 运行时,系统必须持续管理:

  • 哪些请求还活着
  • 每个请求当前的序列长度是多少
  • 每个请求的 cache 在哪里
  • 哪些请求新加入,需要先 prefill
  • 哪些请求结束,可以释放 cache

所以 continuous batching 的难点并不只是“把请求凑一起算”,而是 如何在动态调度下管理各自的 KV Cache。这也是为什么现代高性能推理服务通常会把cache 管理batch 调度看成一个联动问题。

一个直观的多请求时间线

假设有三个请求:

  • 请求 A:先到,生成很长
  • 请求 B:后到,很快结束
  • 请求 C:再后到,中等长度

如果没有动态调度,系统可能要么:

  • 让 B 和 C 等 A
  • 要么反复重新组 batch,代价很高

而 continuous batching 更像:

  • 第 1 轮:只跑 A
  • 第 2 轮:A + B
  • 第 3 轮:A + B + C
  • 第 4 轮:A + C,因为 B 已结束

整个过程中:

  • 每个请求的 KV Cache 都在独立增长
  • 系统需要能灵活找到并访问它们各自的 cache

这时 paged attention 的内存组织优势就会变得非常重要。

Hugging Face Transformers 的缓存文档

Caching

想象一下,你正在和某人对话,但对方并不会记住之前说过的话,而是在你每次回应之后,都要重新从头开始。这显然会很慢,也很低效。Transformer 的自回归生成可以用同样的类比来理解。模型是一个 token 一个 token 往后预测的,而每一步新预测都依赖此前已经出现的全部上下文。要预测第 1000 个 token,模型需要前 999 个 token 里的信息;要预测第 1001 个 token,它又需要前 1000 个 token 的信息。这意味着同一批历史信息会被一遍又一遍地参与计算,其中大量开销都来自对 token 表示的矩阵运算。KV Cache 的目的,就是把已经处理过的 token 在注意力层里得到的 key-value 对保存起来,后面直接取出复用,从而避免重复计算。

Cache 只应该用于推理。
如果在训练时启用 cache,可能会引发意料之外的问题。

Attention matrices

要理解 cache 为什么可行,最好先看 attention 矩阵本身。对于 batch size 为 b、注意力头数为 h、当前序列长度为 T、每个 head 维度为 d_head 的情形,scaled dot-product attention 可以写成:

Attention(Q,K,V)=Softmax(QKdhead+mask)V

其中,query、key、value 都是从输入 embedding 投影得到的,它们的常见形状为:

(b,h,T,dhead)

在 causal attention 里,mask 会阻止模型看见未来 token。于是,一个 token 一旦被处理完,它在面对未来 token 时的表示就不再发生变化,过去 token 对应的 KV 也就可以缓存下来,在后续步骤里继续使用。如果当前正在处理第 t 个 token,那么这一时刻的 attention 可以理解为:

Attention(qt,[k1,k2,,kt],[v1,v2,,vt])

在推理阶段,每一步真正新增的是当前 token 对应的那一小部分信息。新产生的 key 和 value 会被写入 cache,并接到历史 key/value 后面:

Kcacheconcat(Kpast,kt)Vcacheconcat(Vpast,vt)

attention 在模型的每一层里都是独立计算的,所以 cache 也是逐层维护的。不使用 cache 时,每一步都要重新得到此前 token 的 K/V;使用 cache 时,每一步通常只需要为当前 token 计算新的 K/V,历史部分直接复用。这样做的结果是,随着序列变长,推理里的重复工作会显著减少。

without cachingwith caching
每一步都重新计算此前所有 token 的 KV每一步只计算当前 token 的 KV
每步 attention 的代价会随着序列长度快速上升每步计算更容易控制,虽然内存仍会随长度增长

Cache class

从接口角度看,一个最基础的 KV cache 会接收当前 token 的 key 和 value,然后返回更新后的 K/V。这件事通常由模型的 forward 方法在内部处理。

python
new_K, new_V = cache.update(k_t, v_t, layer_idx)
attn_output = attn_layer_idx_fn(q_t, new_K, new_V)

当使用 Transformers 的 Cache 类时,自注意力模块会把过去和现在的信息拼到一起参与计算。当前 token 的 kv 会和 cache 中已经保存的历史 kv 拼接,这样 attention 分数就能同时基于旧上下文和新输入来计算。如果你是反复调用 forward 自己写 generation loop,那么还需要特别注意 attention mask 的长度必须和“历史 kv + 当前 kv”拼接后的总长度一致。generate() 通常会替你处理这些细节,但手写循环时就需要自己保证 mask 同时覆盖过去 token 和当前 token。

Cache storage implementation

在实现上,cache 可以理解成一个按层组织的结构,每一层都有自己的 key cache 和 value cache,而它们的常见形状为:

[batch_size,num_heads,seq_len,head_dim]

不同层可以使用不同类型的 cache layer,例如 DynamicLayerStaticLayerStaticSlidingWindowLayer。它们最主要的差别,在于序列长度如何管理,以及新 token 到来时 cache 怎样更新。最容易理解的是 DynamicLayer,它会随着新 token 的到来不断增长,也就是说 seq_len 会持续增加,更新方式可以理解成沿着序列维不断追加:

python
cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)

StaticLayerStaticSlidingWindowLayer 则会在创建 cache 时就把长度固定下来,因此更容易和 torch.compile 这类静态形状优化配合使用。对于 StaticSlidingWindowLayer,当新 token 进入而窗口又已经满了时,较早的 token 会被移出 cache。

下面是一个用 DynamicCache 手写生成循环的例子。关键点不在某个具体模型,而在调用方式本身:第一次把整段输入送进模型,之后每一轮通常只送入尚未处理的新 token,同时持续传入 past_key_values,并扩展 attention_mask,让它始终覆盖历史 token 和当前新增 token。

python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from accelerate import Accelerator

device = Accelerator().device

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
past_key_values = DynamicCache(config=model.config)
messages = [{"role": "user", "content": "Hello, what's your name."}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)

generated_ids = inputs.input_ids
max_new_tokens = 10

for _ in range(max_new_tokens):
    outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    next_token_ids = outputs.logits[:, -1:].argmax(-1)
    generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)

    attention_mask = inputs["attention_mask"]
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
        dim=-1,
    )
    inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}