历史前缀中,到底什么值得缓存?
答案是:
- 各层历史 token 的
K - 各层历史 token 的
V
这就是 KV Cache。
但要真正理解它,不能只停在名字上,必须回答三个问题:
- 每层到底缓存什么对象
- 为什么这些对象能复用
- 为什么通常不缓存
Q
缓存什么与为什么只缓存KV
从 attention 公式出发
对 decoder-only 模型的某一层来说,对于第
这里:
是第 个 token 进入该层后的表示 是当前 token 发出的查询 是当前 token 提供给未来位置检索的键 是当前 token 携带的内容值
当模型生成第
也就是说,当前 token 会用自己的
这已经暗示了一件事:
历史位置的
K/V会在未来被反复使用。
KV Cache 到底缓存什么
在最基本的意义上,它保存的是:
- 每一层
- 每一个已经出现过的历史 token
- 对应产生的
K和V
也就是说,如果模型有
其中第
这里:
表示第 层从第 1 个到第 个 token 的所有 key 表示第 层从第 1 个到第 个 token 的所有 value
所以 KV Cache 并不是只缓存“最后一层”的结果,而是:每一层都有自己的 K/V cache
这是因为:
- 第 1 层生成的 K/V 只能给第 1 层的 attention 用
- 第 2 层生成的 K/V 只能给第 2 层用
- 各层的表示不同,不能混用
为什么历史 K/V 具有复用价值
假设我们已经处理过这些 token:
[A, B, C]当生成下一个 token D 时,在某一层 attention 中:
D会产生自己的- 然后
需要去查询历史的 - 再用权重聚合
如果再生成下一个 token E,会发生什么?
E会产生自己的仍然需要去查询历史的 - 再聚合
这意味着:
A/B/C/D的K/V会被后续 token 一次又一次访问- 它们是天然可复用的“历史索引与内容表征”
因此缓存历史 K/V 是非常合理的:
- 历史 token 已经固定
- 它们在某层产生的 K/V 也已经固定
- 后续 token 只需要重复读取,而不需要重新生成
为什么通常不缓存 Q
通常不把 Q 作为长期 cache 的主体,原因是:
Q是“当前步查询历史”的临时量- 它不像历史
K/V那样会被未来很多步重复使用
在 attention 中:
Q的角色是“我要找什么”K的角色是“我是什么线索”V的角色是“我携带什么内容”
当新 token 到来时:
- 它会产生自己的
Q - 用这个
Q去查询历史K - 然后读取对应
V
这里最关键的是:
- 当前 token 的
Q只服务于当前这一步的查询 - 历史 token 的
K/V会服务于未来许多步
所以二者的复用价值完全不同。
KV Cache 在每一步是如何更新的
假设当前已经处理到第
第
当第
- 第一步
计算当前 token 在第
- 第二步
把新的
- 第三步
使用当前
这就完成了第
Prefill 阶段和 Decode 阶段下,KV Cache 的形成方式
在 prefill 阶段:
- prompt 中每个 token 都会在每层产生自己的
K/V - 整段 prompt 处理完后,cache 里已经装好了 prompt 对应的历史
K/V
也就是说,prefill 完成后,你已经拥有:
- “prompt 的 cache”
在 decode 阶段:
- 每生成一个新 token,就只额外生成这个 token 的
K/V - 然后把它 append 到已有 cache 后面
所以可以把整个过程理解成:
- prefill:一次性把初始仓库建起来
- decode:每来一个新 token,就在仓库后面追加一小段
误区:KV Cache 只缓存 prompt 的 K/V
它会先缓存 prompt 的 K/V,然后在 decode 中继续追加新生成 token 的 K/V。
KV Cache 如何减少重复计算
但 KV Cache 并没有把 attention 变成常数复杂度
即使使用了 KV Cache,decode 时当前 token 仍然要:
- 用自己的
Q去和历史所有K做匹配 - 再聚合历史所有
V
这意味着:
- 当前步的开销仍然会随着历史长度增长
更直白地说:
KV Cache消灭的是“历史 K/V 的重复生成”- 它没有消灭“当前 token 访问整段历史”的需求
KV Cache 的代价是什么
它把一部分重复计算,换成了长期保存历史 K/V 的内存开销。
也就是说:
- 不用 cache:算得慢,但省掉一部分运行时历史存储
- 用 cache:算得快,但要花显存保存历史 K/V
所以 KV Cache 的本质是一种典型的时间换空间:
- 用额外内存换更快的 decode
误区:只要模型小,就不需要 KV Cache
只要是 decoder-only 自回归生成,历史重复计算就存在,只是模型越大、上下文越长,这个问题越突出。
误区:KV Cache 只对长 prompt 有用
它对长 prompt 很重要,但对长输出同样重要,因为 decode 会不断增长历史长度。
KV Cache 的张量形状、显存占用与位置处理
单层 KV Cache 的抽象形状
先从最经典的多头注意力开始。
设:
- batch size 为
- 注意力头数为
- 当前序列长度为
- 每个 head 的维度为
那么单层的 K/V cache 常常可以抽象为:
这表示:
- 有
个样本 - 每个样本有
个头 - 每个头存了长度为
的历史序列 - 每个位置对应一个
维向量
在不同框架里,维度顺序可能有所不同,例如也可能写成:
但本质不变:
- 层
- 批次
- 头
- 序列长度
- 每头维度
这几个维度总会出现。
为什么每一层都要存一份
如果模型有
原因很简单:
- 第 1 层的
K/V来自第 1 层输入表示 - 第 2 层的
K/V来自第 2 层输入表示 - 各层的表示空间不同
所以总 cache 占用会和层数线性相关。
Decode 时 cache 是如何增长的
假设某一层当前已有:
现在新生成了 1 个 token。
这个 token 在该层会产生:
于是 cache 更新成:
更新后形状变成:
KV Cache 的显存占用如何估算
对于标准多头注意力,KV Cache 的元素数量大致是:
这里前面的 2 表示:
- 一份
K - 一份
V
如果每个元素占 bytes 个字节,那么总字节数近似为:
这个公式直接告诉你:
- 层数翻倍,cache 近似翻倍
- batch 翻倍,cache 近似翻倍
- 上下文长度翻倍,cache 近似翻倍
- 头数翻倍,cache 近似翻倍
也就是说,KV Cache 是一个典型会随着多种维度同时线性膨胀的对象。
假设模型参数如下:
- 层数
- 注意力头数
- 每头维度
- batch size
- 序列长度
- 数据类型是
fp16,每个元素 2 字节
那么总占用近似为:
计算后约为:
4,294,967,296 bytes ≈ 4 GiB这只是:
- 1 个请求
- 8K 上下文
如果 batch size 变成 4,那么 cache 就会接近 16 GB。
而这还没算:
- 模型权重
- 激活
- 临时工作区
- 采样和调度的额外开销
所以你可以直观看到:
长上下文和高并发时,KV Cache 很容易成为显存主瓶颈。
为什么 prompt 长度和输出长度都会影响 cache
总序列长度
这意味着:
- prompt 越长,prefill 结束后初始 cache 就越大
- 生成越长,decode 阶段 cache 还会继续增长
所以:
- 长 prompt 会拖高起点
- 长输出会让 cache 持续膨胀
两者都会推高显存占用。
Batch 对 KV Cache 的影响
在单请求视角下,cache 看起来已经不小。
一旦进入服务端场景,多请求 batch 进来时,情况会更明显。
因为 cache 大小近似正比于
也就是说:同样的模型,同样的上下文长度,batch size 越大,cache 越大,这就是为什么很多推理系统在追求吞吐时,总会面临一个权衡:大 batch 有利于吞吐,但大 batch 会迅速吃掉显存
这也解释了后面为什么需要:MQA / GQA、KV 量化、paged attention
位置编码为什么会和 KV Cache 绑在一起
decode 不是一次性把整段序列重新输入,而是:
- 历史 token 已经在 cache 中
- 当前只新增 1 个 token
于是模型必须明确知道:
- 新 token 处于第几个位置
- 历史 cache 中的每个 token 对应哪个位置
如果位置处理不一致,注意力就会错位。
绝对位置编码和 decode
对于最朴素的绝对位置编码:
- 第 1 个 token 用位置 0
- 第 2 个 token 用位置 1
- 第 3 个 token 用位置 2
decode 时,当新 token 到来:
- 它必须拿到新的 position id
- 不能把它误当成历史某个旧位置
如果 prompt 长度已经是
RoPE
很多现代 LLM 并不用最原始的绝对位置编码,而更常使用:RoPE,即 Rotary Position Embedding。
从高层直觉上理解,RoPE 的特点是:
- 把位置信息编码进 query 和 key
- 让注意力更自然地表达相对位置信息
为什么很多实现要显式维护 position ids
在 decode 阶段,模型不再一次性看到整段序列,而通常只看到:
- 当前新增 token
- 以及历史 cache
所以框架必须显式知道:
- 当前 token 是第几个位置
- 是否因为滑动窗口或截断发生了位置偏移
- 多请求混合时各自的位置进度是多少
这就是为什么很多推理代码里你会看到:
position_idscache_positionpast_length
之类的变量。
它们不是实现噪声,而是在维护:历史 cache 与当前 token 的位置一致性
滑动窗口和 cache 截断
如果模型只支持有限上下文,比如最多
- 只保留最近
个 token 的 cache - 更早的历史被丢弃
这就是所谓的滑动窗口思路。
此时实现会更复杂,因为你需要同时处理:
- cache 截断
- 注意力可见范围
- 位置编号或相对位置解释
TIP
KV Cache 并不一定永远无限增长
在有限上下文模型里,它可能会被窗口化管理
KV Cache 的常见优化:MQA、GQA、KV 量化
如果标准多头注意力下的 KV Cache 太大,怎么办?
现代模型和推理系统里,最常见的继续优化方向有三类:
- 减少要缓存的
K/V head数量 - 降低 cache 的数值精度
- 更聪明地组织和管理 cache
标准 MHA 为什么会有较大的 KV Cache
在标准多头注意力 MHA 中,通常:
- query 有
个头 - key 也有
个头 - value 也有
个头
也就是说,Q/K/V 头数是一一对应的。
这意味着对每个 token、每一层来说,KV cache 大小大致与:
成正比。
于是上下文一长、层数一多、batch 一上来,cache 就会很大。
MQA:Multi-Query Attention
MQA 的核心思路非常直接:
- 让 query 仍然保持多头
- 但让所有 query 头共享同一组 key/value 头
也就是说:
Q还是多头K/V变成极少头,最极端时就是 1 头
你可以把它理解成:
- 多个 query 头从不同角度提问
- 但大家都去查同一份
K/V索引库
如果标准 MHA 是:
那么 MQA 近似变成:
也就是相对标准 MHA,cache 大小大约缩小为
如果
MQA 的代价也很明显:
- 所有 query 头共享同一份
K/V - 表达能力可能不如每个头都有独立
K/V的标准 MHA
从直觉上看:
- 你节省了大量 cache
- 但也减少了 key/value 表示的多样性
所以它是一个典型的:
- 速度和显存收益很大
- 表达能力可能有一定折中
GQA:Grouped-Query Attention
GQA 可以理解成:
- 不是像 MQA 那样让所有 query 头都共享一组
K/V - 而是把 query 头分成若干组
- 每组共享一组
K/V
举例来说:
- query 头数是 32
- key/value 头数改成 8
那么每 4 个 query 头共享 1 组 K/V。
因此:
- 它比标准 MHA 的 cache 小
- 又没有 MQA 那么极端
如果:
- query 头数为
- key/value 头数为
那么 GQA 的 cache 大小大致与:
成正比。
相对标准 MHA,缩减比例近似为:
例如:
那么 cache 大约是原来的
很多人喜欢 GQA 就是因为它是一个非常自然的折中:比 MHA 更省内存,比 MQA 更保留表达能力。
| 机制 | Query 头数 | KV 头数 | cache 大小 | 表达能力直觉 |
|---|---|---|---|---|
| MHA | 多 | 多 | 最大 | 最强基线 |
| GQA | 多 | 较少 | 中等 | 较好折中 |
| MQA | 多 | 1 | 最小 | 压缩最激进 |
KV 量化:另一条思路
除了减少 KV 头数,另一条自然路线是:
不改变 cache 的结构,但降低每个元素的存储精度。
这就是 KV 量化。
如果你原来用的是:
fp16,每个元素约 2 字节
改成:
int8,每个元素约 1 字节
理论上 cache 占用就可以接近减半。
如果再往更低位走,例如更激进的量化,理论上还可以继续下降。
KV 量化 不改变:
- 层数
- 序列长度
- 头数
- 每头维度
它改变的是:
bytes_per_elem
也就是显存公式中的最后一项。
但量化从来不是白捡收益。
精度损失
精度越低,近似误差通常越大。
而 K/V 又直接参与 attention:
- key 决定匹配权重
- value 决定聚合内容
所以量化过猛可能影响生成质量。
反量化开销
很多实现不会让量化后的值永久直接参与所有计算,而会在一定步骤中做反量化或混合精度处理。
这意味着:
- 你省了内存
- 但也可能引入额外计算和实现复杂度
不同硬件友好度不同
某些硬件和 kernel 对特定精度格式更友好。
所以 KV 量化 的真实收益往往既取决于算法,也取决于实现栈。
什么时候优先考虑哪种优化
如果主要瓶颈是 cache 过大:GQA / MQA、KV 量化
如果你担心表达能力损失太大:GQA
如果你不想动模型结构,只想省显存:KV 量化
现代推理系统里的 KV Cache:Paged Attention、Prompt Cache 与 Continuous Batching
真实服务里经常同时发生这些事:
- 很多请求一起到来
- 每个请求 prompt 长度不同
- 每个请求生成长度不同
- 有些请求共享相同前缀
- 请求会动态进出 batch
这时“只知道有 KV Cache”还远远不够。
还必须理解:
- 如何组织 cache 才不浪费内存
- 如何复用共享前缀
- 如何让不同请求一起高效解码
从单请求到多请求,问题为什么会突然变复杂
在单请求视角里,cache 似乎只是:一个随着时间增长的历史 K/V 仓库
但多请求一来,系统立刻会遇到更多约束:
长度不一致
请求 A 的序列长度可能是 300。
请求 B 的序列长度可能是 4000。
请求 C 的序列长度可能是 12000。
它们的 cache 大小完全不同。
生命周期不一致
有的请求刚开始 prefill。
有的请求已经在 decode 第 50 步。
有的请求下一步就会结束。
前缀可能重复
例如多个用户都在问同样的系统提示词或模板化 prompt:
You are a helpful assistant...如果每个请求都重新构建一遍相同前缀的 cache,就会浪费很多计算和内存。
显存是离散分配的
不是说“想用多少就精确切多少”,系统通常要面对:
- 内存块
- 分配和回收
- 碎片化
所以一旦进入服务层,真正的问题就变成:
如何让大量动态请求共享、复用并高效管理 KV Cache。
Paged Attention 想解决什么问题
先从最容易理解的痛点开始:
- cache 很大
- 请求长度不一致
- 动态分配和回收会带来内存碎片
如果用最朴素的连续内存方式存每个请求的整段 cache,可能会出现:
- 有些请求结束后,留下很多大小不一的空洞
- 新请求来了,不一定能正好塞进这些空洞
- 显存明明还有余量,但因为碎片化,实际很难高效利用
Paged Attention 的核心思想可以粗略理解成:
不把每个请求的 KV Cache 视为必须连续存放的一整段,而是切分成固定大小的块来管理。
这个思路和操作系统里“页”的概念很像,所以叫 paged。
用“书页”类比 Paged Attention
把一个请求的历史 cache 想成一本书。
要求这本书的所有页必须连续摆在同一块大桌子上。
问题是:
- 桌子空间会被零碎切开
- 新书大小不一,很难总找到刚好连续的空位
Paged 方式改成:
- 把书拆成固定大小的页块
- 每一页可以放在不同位置
- 系统维护“逻辑页号 -> 物理页块”的映射
所以 Paged Attention 的重点不是“改变注意力公式”,而是改变 cache 的内存组织方式
Paged Attention 本质上优化的是什么
它主要优化的是:
- cache 的分配
- cache 的回收
- cache 的布局
- 多请求下的内存利用率
它并不直接改变:
- decoder-only 自回归生成的定义
- KV Cache 的基本含义
- attention 公式本身
所以其本质为:系统层 / 内存管理层优化
Prompt Cache / Prefix Cache 想解决什么问题
现在来看另一个常见优化:
prompt cache- 或者
prefix cache
它们关注的问题是:
如果多个请求有相同前缀,为什么还要每次都重新 prefill 一遍?
举个例子。
假设很多请求都有同样的系统提示词:
You are an expert coding assistant...如果每次来新请求,都要对这段前缀做完整 prefill,那么会浪费:
- 计算
- cache 空间
于是一个自然思路是:
- 先把这段前缀对应的历史
K/V保留下来 - 新请求如果前缀完全一样,就直接复用
Prompt Cache 的收益来自哪里
收益主要来自两个方面。
- 减少重复 prefill
相同 prompt 不必每次从头构建一遍历史 cache。
- 降低首 token 延迟
因为一部分 prefill 工作已经提前做过了。
所以它最典型的收益场景是:
- 系统提示固定
- 模板化 prompt 多
- 多轮对话中前缀重复度高
Prompt Cache 的限制
它并不是“凡是像一点就能复用”。
最直接的限制是:
- 前缀必须足够一致
在最保守的实现中,往往要求:
- token 序列完全一致
如果前缀只差一个 token:
- 之前的 prefix cache 可能就不能直接原样复用
此外,它还会遇到:
- cache 淘汰策略
- 多租户隔离
- 一致性与命中率权衡
所以 prompt cache 非常有用,但也高度依赖实际业务分布。
Continuous Batching 想解决什么问题
接下来是服务层另一个非常重要的概念:
continuous batching
它关注的问题是:
请求不会整齐地同时开始和同时结束,系统怎么让 GPU 一直保持高利用率?
如果只做静态 batch,通常会很浪费:
- 一批请求一起进来
- 等到其中长请求拖很久时,短请求已经结束
- 很多算力会因为等待和填充而被浪费
continuous batching 的核心思路是:
- 不要求所有请求固定成一批到底
- 请求可以动态加入和退出 batch
- 每一轮 decode 都尽量把当前可执行的请求拼起来一起跑
这是一种调度方式上的优化。
为什么 Continuous Batching 和 KV Cache 紧密相关
因为动态 batch 运行时,系统必须持续管理:
- 哪些请求还活着
- 每个请求当前的序列长度是多少
- 每个请求的 cache 在哪里
- 哪些请求新加入,需要先 prefill
- 哪些请求结束,可以释放 cache
所以 continuous batching 的难点并不只是“把请求凑一起算”,而是 如何在动态调度下管理各自的 KV Cache。这也是为什么现代高性能推理服务通常会把cache 管理和batch 调度看成一个联动问题。
一个直观的多请求时间线
假设有三个请求:
- 请求 A:先到,生成很长
- 请求 B:后到,很快结束
- 请求 C:再后到,中等长度
如果没有动态调度,系统可能要么:
- 让 B 和 C 等 A
- 要么反复重新组 batch,代价很高
而 continuous batching 更像:
- 第 1 轮:只跑 A
- 第 2 轮:A + B
- 第 3 轮:A + B + C
- 第 4 轮:A + C,因为 B 已结束
整个过程中:
- 每个请求的 KV Cache 都在独立增长
- 系统需要能灵活找到并访问它们各自的 cache
这时 paged attention 的内存组织优势就会变得非常重要。
Hugging Face Transformers 的缓存文档
Caching
想象一下,你正在和某人对话,但对方并不会记住之前说过的话,而是在你每次回应之后,都要重新从头开始。这显然会很慢,也很低效。Transformer 的自回归生成可以用同样的类比来理解。模型是一个 token 一个 token 往后预测的,而每一步新预测都依赖此前已经出现的全部上下文。要预测第 1000 个 token,模型需要前 999 个 token 里的信息;要预测第 1001 个 token,它又需要前 1000 个 token 的信息。这意味着同一批历史信息会被一遍又一遍地参与计算,其中大量开销都来自对 token 表示的矩阵运算。KV Cache 的目的,就是把已经处理过的 token 在注意力层里得到的 key-value 对保存起来,后面直接取出复用,从而避免重复计算。
Cache 只应该用于推理。
如果在训练时启用 cache,可能会引发意料之外的问题。
Attention matrices
要理解 cache 为什么可行,最好先看 attention 矩阵本身。对于 batch size 为 b、注意力头数为 h、当前序列长度为 T、每个 head 维度为 d_head 的情形,scaled dot-product attention 可以写成:
其中,query、key、value 都是从输入 embedding 投影得到的,它们的常见形状为:
在 causal attention 里,mask 会阻止模型看见未来 token。于是,一个 token 一旦被处理完,它在面对未来 token 时的表示就不再发生变化,过去 token 对应的 K 和 V 也就可以缓存下来,在后续步骤里继续使用。如果当前正在处理第 t 个 token,那么这一时刻的 attention 可以理解为:
在推理阶段,每一步真正新增的是当前 token 对应的那一小部分信息。新产生的 key 和 value 会被写入 cache,并接到历史 key/value 后面:
attention 在模型的每一层里都是独立计算的,所以 cache 也是逐层维护的。不使用 cache 时,每一步都要重新得到此前 token 的 K/V;使用 cache 时,每一步通常只需要为当前 token 计算新的 K/V,历史部分直接复用。这样做的结果是,随着序列变长,推理里的重复工作会显著减少。
| without caching | with caching |
|---|---|
每一步都重新计算此前所有 token 的 K 和 V | 每一步只计算当前 token 的 K 和 V |
| 每步 attention 的代价会随着序列长度快速上升 | 每步计算更容易控制,虽然内存仍会随长度增长 |
Cache class
从接口角度看,一个最基础的 KV cache 会接收当前 token 的 key 和 value,然后返回更新后的 K/V。这件事通常由模型的 forward 方法在内部处理。
new_K, new_V = cache.update(k_t, v_t, layer_idx)
attn_output = attn_layer_idx_fn(q_t, new_K, new_V)当使用 Transformers 的 Cache 类时,自注意力模块会把过去和现在的信息拼到一起参与计算。当前 token 的 kv 会和 cache 中已经保存的历史 kv 拼接,这样 attention 分数就能同时基于旧上下文和新输入来计算。如果你是反复调用 forward 自己写 generation loop,那么还需要特别注意 attention mask 的长度必须和“历史 kv + 当前 kv”拼接后的总长度一致。generate() 通常会替你处理这些细节,但手写循环时就需要自己保证 mask 同时覆盖过去 token 和当前 token。
Cache storage implementation
在实现上,cache 可以理解成一个按层组织的结构,每一层都有自己的 key cache 和 value cache,而它们的常见形状为:
不同层可以使用不同类型的 cache layer,例如 DynamicLayer、StaticLayer 和 StaticSlidingWindowLayer。它们最主要的差别,在于序列长度如何管理,以及新 token 到来时 cache 怎样更新。最容易理解的是 DynamicLayer,它会随着新 token 的到来不断增长,也就是说 seq_len 会持续增加,更新方式可以理解成沿着序列维不断追加:
cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)StaticLayer 和 StaticSlidingWindowLayer 则会在创建 cache 时就把长度固定下来,因此更容易和 torch.compile 这类静态形状优化配合使用。对于 StaticSlidingWindowLayer,当新 token 进入而窗口又已经满了时,较早的 token 会被移出 cache。
下面是一个用 DynamicCache 手写生成循环的例子。关键点不在某个具体模型,而在调用方式本身:第一次把整段输入送进模型,之后每一轮通常只送入尚未处理的新 token,同时持续传入 past_key_values,并扩展 attention_mask,让它始终覆盖历史 token 和当前新增 token。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from accelerate import Accelerator
device = Accelerator().device
model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
past_key_values = DynamicCache(config=model.config)
messages = [{"role": "user", "content": "Hello, what's your name."}]
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt",
return_dict=True,
).to(model.device)
generated_ids = inputs.input_ids
max_new_tokens = 10
for _ in range(max_new_tokens):
outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
next_token_ids = outputs.logits[:, -1:].argmax(-1)
generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
attention_mask = inputs["attention_mask"]
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
dim=-1,
)
inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}