Skip to content

Transformer

RNN 的主要问题:

  • 梯度消失/爆炸:长距离依赖难以学习
  • 顺序计算:无法并行处理序列
  • 信息瓶颈:最后时刻隐藏状态需承载全部信息

Transformer 的改进:

  • 并行计算:同时处理整个序列
  • 自注意力机制:直接建立任意位置间的联系
  • 位置编码:显式注入位置信息

结构

Transformer

图:Transformer 单元

Transformer

图:Transformer 的详细结构

  • 输入
    • 编码器输入
    • 解码器输入
  • 输出
    • 线性层
    • Softmax 层
  • 编码器
    • 由 N 个编码器层堆叠而成
    • 每个编码器层由两个子层连接结构组成
    • 第一个子层连接结构包括一个多头自注意力子层规范化层以及一个残差连接
    • 第二个子层连接结构包括一个前馈全连接子层规范化层以及一个残差连接
  • 解码器
    • 由 N 个解码器层堆叠而成
    • 每个解码器层由三个子层连接结构组成
    • 第一个子层连接结构包括一个带掩码的-多头自注意力子层规范化层以及一个残差连接
    • 第二个子层连接结构包括一个多头注意力子层(编码器到解码器)和规范化层以及一个残差连接
    • 第三个子层连接结构包括一个前馈全连接子层规范化层以及一个残差连接

核心组件

自注意力机制(Self-Attention)

在传统的神经网络处理序列时,模型只能一步步按顺序处理,难以捕捉长距离依赖关系。自注意力机制就是为了让序列中的每个元素都能直接与序列中所有其他元素进行交互,无论它们直接的距离多远。

定义:

符号维度含义
Xn×d输入矩阵(n=序列长度,d=特征维度)
Qn×dkQuery 矩阵(查询向量)
Kn×dkKey 矩阵(键向量)
Vn×dvValue 矩阵(值向量)
WQ,WK,WVd×dk/dv可学习参数矩阵
Attention(Q,K,V)=Softmax(QKTdk)V

TIP

  • Q 表示当前需要关注的信息或问题,用于确定输入序列中哪些部分与当前任务相关
  • K 用于匹配查询,通过计算相似度判断输入序列中哪些元素与查询匹配
  • V 存储实际信息

推导过程

将输入转换为 Query、Key、Value:

Q=XWQ,K=XWK,V=XWV

计算注意力分数:

Scores=QKTdk

生成注意力权重矩阵:

A=Softmax(Scores)

得到最终注意力输出:

Output=AV

带掩码自注意力层(Masked Multi-head attention)

编码时,对于 t 时刻的预测,我们知道 x1,x2,,xt,xt+1,,xT 全部的信息。

解码时,对于 t 时刻的预测,我们仅知道 x1,x2,,xt1 的信息。看不到后续的信息,因此需要将后续的信息遮掩起来。

Attention(Q,K,V)=Softmax(QKTMdk)V

多头注意力(Multi-Head Attention)

Transformer

MultiHead(Q,K,V)=Concat(head1,,headh)WO

其中:

headi=Attention(QWiQ,KWiK,VWiV)

位置编码(Positional Encoding)

RNNLSTM 等顺序算法不同,Transformer 没有内置机制来捕获句子中单词的相对位置,所以在 Transformerencoderdecoder 的输入层中,使用了 Positional Encoding,使得最终的输入满足:

input=input_embedding+positional_encoding

原始正弦编码公式:

PE(pos,2i)=sin(pos100002i/d)PE(pos,2i+1)=cos(pos100002i/d)

前馈网络(Feed Forward Network)

包括两个线性变换+ReLU 激活:

FFN(x)=ReLU(xW1+b1)W2+b2

计算复杂度

当输入批次大小为 b,序列长度为 N,词向量的维度(隐藏层的维度)为 d 时,l 层 transformer 的计算复杂度:

Self-Attention 层

FLOPs(Self-Attention)=8bNd2+4bN2d
  1. 计算 QKV

输入输出

[b,N,d]×[d,d][b,N,d]

计算量为:

FLOPs=3QKVb(Ndd[N,d]×[d,d]乘法+Ndd[N,d]×[d,d]加法)=6bNd2

TIP

矩阵加法运算考虑偏差 bias计算量就是 Ndd,如果不考虑偏差就是 Nd(d1),但这个 1 一般忽略不计。

  1. 计算 QKT

输入输出

[b,h,N,dk]×[b,h,dk,N][b,h,N,N]

h 为注意力头数,dk 为每个头的维度,hdk=d

FLOPs=bh(N2dk+N2dk)=2bN2d
  1. Softmax 与加权求和

Softmax 计算量较小,通常忽略。

输入输出

[b,h,N,N]×[b,h,N,dk][b,h,N,dk]
FLOPs=bh(NdkN+NdkN)=2bN2d
  1. 输出投影

线性变换将结果映射回 n 维:

输入输出

[b,N,d]×[d,d][b,N,d]
FLOPs=2bNd2

MLP 层

FLOPs(MLP)=16bNd2
  1. 线性层(扩展层)

输入输出

[b,N,d]×[d,4d][b,N,4d]
FLOPs=8bNd2
  1. 线性层(压缩层)

输入输出

[b,N,4d]×[4d,d][b,N,d]
FLOPs=8bNd2

logits

Logits 层是将最终的 Transformer 隐藏层输出(维度 d)映射到词表大小 V,即一个线性投影:

输入输出

[b,N,d]×[d,V][b,N,V]
FLOPs(logits)=2bNdV

总的计算复杂度

FLOPs(Transformer)=l(24bNd2+4bN2d)+2bNdV

空间复杂度

大模型在训练过程中通常采用混合精度训练,中间激活值一般是 float16 或者 bfloat16 数据类型的。在分析中间激活的显存占用时,假设中间激活值是以 float16 或 bfloat16 数据格式来保存的,每个元素占了 2 个 bytes,dropout 操作的 mask 矩阵,每个元素只占 1 个 bytes。需要保存的中间激活占用显存大小计算如下:

Self-Attention 层

  1. QKV 共享一个输入 X,则显存占用为 2bNd
  2. 对于 QKT,两个张量形状都是 [b,N,d],显存占用为 4bNd
  3. 对于 Softmax,函数输入 QKT 形状为 [b,h,N,N],显存占用为 2bN2h
  4. 计算完 Softmax,会进行 dropout,需要保存一个 mask 矩阵,其形状与 QKT 相同,显存占用为 bN2h
  5. 计算 ScoresV,二者占用显存大小为 2bN2h+2bNd
  6. 计算输出映射和一个 dropout 操作,二者占用显存大小为 2bNd+bNd

综上,Self-Attention 层的显存占用为 11bNd+5bN2a

MLP 层

  1. 第一个线性层的输入占用显存 2bNd
  2. 激活函数的输入占用显存 8bNd
  3. 第二个线性层的输入占用显存 8bNd
  4. 最后的 dropout 操作需要保存的 mask 矩阵占用显存 bNd

综上,MLP 层的显存占用为 19bNd

LN

Self-Attention 层和 MLP 层分别对应了一个 LN,其输入占用显存为 2bNd+2bNd

总的空间复杂度

l(34bNd+5bN2h)

问题

Transformer 的计算复杂度为:l(24bNd2+4bN2d)+2bNdV,需要保存的中间激活占用显存大小为:l(34bNd+5bN2h),即 Transformer 模型的计算量和储存复杂度随着序列长度 N 呈二次方增长。

可以注意到,4bN2d5bN2h 均产生于 Self-Attention 层。

架构

NLP/LLM 最经典的三类 Transformer 架构是 encoder-only、decoder-only 和 encoder-decoder。

  • encoder-only:双向看上下文,擅长分类、匹配、抽取和表示学习。
  • decoder-only:只能看左侧历史,天然适合 next-token prediction 和开放生成。
  • encoder-decoder:先编码输入,再在 cross-attention 条件下生成输出,适合翻译、摘要、问答。
  • causal maskbidirectional maskcross-attention 都是在描述这三类结构的差异。
  • 现代 LLM 主流是 decoder-only,但检索、reranker 和很多 NLP 任务仍常用 encoder。

代码

nn.Transformer

只负责主体结构

python
import torch
import torch.nn as nn

transformer = nn.Transformer(
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    batch_first=False,
)
text
d_model             每个 token 的特征维度
nhead               多头注意力的 head 数量
num_encoder_layers  encoder 堆叠多少层
num_decoder_layers  decoder 堆叠多少层
dim_feedforward     FFN 中间层维度
dropout             dropout 概率
activation          FFN 中间激活函数,默认 relu,也支持 gelu
batch_first         是否使用 [batch, seq, feature]
norm_first          LayerNorm 放在 attention/FFN 之前还是之后
bias                Linear 和 LayerNorm 是否使用 bias

__init__:先创建 Encoder,再创建 Decoder

python
class Transformer(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        custom_encoder=None,
        custom_decoder=None,
        layer_norm_eps=1e-5,
        batch_first=False,
        norm_first=False,
        bias=True,
        device=None,
        dtype=None,
    ):
        super().__init__()

        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                layer_norm_eps=layer_norm_eps,
                batch_first=batch_first,
                norm_first=norm_first,
                bias=bias,
            )

            encoder_norm = LayerNorm(
                d_model,
                eps=layer_norm_eps,
                bias=bias,
            )

            self.encoder = TransformerEncoder(
                encoder_layer,
                num_encoder_layers,
                encoder_norm,
            )

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                activation=activation,
                layer_norm_eps=layer_norm_eps,
                batch_first=batch_first,
                norm_first=norm_first,
                bias=bias,
            )

            decoder_norm = LayerNorm(
                d_model,
                eps=layer_norm_eps,
                bias=bias,
            )

            self.decoder = TransformerDecoder(
                decoder_layer,
                num_decoder_layers,
                decoder_norm,
            )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead
        self.batch_first = batch_first

_reset_parameters():初始化所有矩阵参数

python
def _reset_parameters(self):
    for p in self.parameters():
        if p.dim() > 1:
            xavier_uniform_(p)

含义是:

text
二维或更高维参数:重新 Xavier uniform 初始化
一维参数:跳过,比如 bias、LayerNorm weight/bias

比如:

python
nn.Linear(512, 2048)

它的权重矩阵 shape 是 [2048, 512]

所以:

text
fan_in = 512
fan_out = 2048

Xavier uniform 的范围是:

WU(a,a)

其中:

a=6fanin+fanout

也就是:

a=6512+20480.0484

所以权重大概从 [-0.0484, 0.0484] 里采样。

forward():先 Encoder,再 Decoder

nn.Transformer.forward() 简化版:

python
def forward(
    self,
    src,
    tgt,
    src_mask=None,
    tgt_mask=None,
    memory_mask=None,
    src_key_padding_mask=None,
    tgt_key_padding_mask=None,
    memory_key_padding_mask=None,
    src_is_causal=None,
    tgt_is_causal=None,
    memory_is_causal=False,
):
    is_batched = src.dim() == 3

    if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
        raise RuntimeError("the batch number of src and tgt must be equal")

    elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
        raise RuntimeError("the batch number of src and tgt must be equal")

    if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
        raise RuntimeError("the feature number of src and tgt must be equal to d_model")

    memory = self.encoder(
        src,
        mask=src_mask,
        src_key_padding_mask=src_key_padding_mask,
        is_causal=src_is_causal,
    )

    output = self.decoder(
        tgt,
        memory,
        tgt_mask=tgt_mask,
        memory_mask=memory_mask,
        tgt_key_padding_mask=tgt_key_padding_mask,
        memory_key_padding_mask=memory_key_padding_mask,
        tgt_is_causal=tgt_is_causal,
        memory_is_causal=memory_is_causal,
    )

    return output

Shape 规则

设:

text
S = source sequence length
T = target sequence length
N = batch size
E = feature dimension = d_model

当:

python
batch_first=False

输入输出是:

python
src:    [S, N, E]
tgt:    [T, N, E]
memory: [S, N, E]
output: [T, N, E]

当:

python
batch_first=True

输入输出是:

python
src:    [N, S, E]
tgt:    [N, T, E]
memory: [N, S, E]
output: [N, T, E]

TransformerEncoder:复制 N 个 EncoderLayer

简化版:

python
class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()

        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None, is_causal=None):
        output = src

        for mod in self.layers:
            output = mod(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                is_causal=is_causal,
            )

        if self.norm is not None:
            output = self.norm(output)

        return output

所以 encoder 的逻辑非常直接:

text
output = src
for layer in encoder_layers:
    output = layer(output)
output = final_norm(output)

如果 num_encoder_layers=6,大概就是:

text
src

EncoderLayer 1

EncoderLayer 2

EncoderLayer 3

EncoderLayer 4

EncoderLayer 5

EncoderLayer 6

encoder_norm

memory

TransformerDecoder:复制 N 个 DecoderLayer

decoder 的结构和 encoder 很像,只是每一层额外需要 memory

python
class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None):
        super().__init__()

        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        tgt_is_causal=None,
        memory_is_causal=False,
    ):
        output = tgt

        for mod in self.layers:
            output = mod(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                tgt_is_causal=tgt_is_causal,
                memory_is_causal=memory_is_causal,
            )

        if self.norm is not None:
            output = self.norm(output)

        return output

TransformerEncoderLayer:Self-Attention + FFN

EncoderLayer 是 Transformer encoder 的基本单元。官方代码里,它主要包含这些模块:

python
self.self_attn = MultiheadAttention(d_model, nhead)

self.linear1 = Linear(d_model, dim_feedforward)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model)

self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)

self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
text
x
 ├── Self-Attention
 ├── Add & Norm
 ├── Feed Forward Network
 └── Add & Norm

FFN 部分是:

python
linear2(dropout(activation(linear1(x))))

也就是:

text
d_model
  ↓ Linear
dim_feedforward
  ↓ activation
dim_feedforward
  ↓ Dropout
dim_feedforward
  ↓ Linear
d_model

例如默认参数:

python
d_model = 512
dim_feedforward = 2048

则 FFN 是:

text
512 -> 2048 -> 512

EncoderLayer 的 forward:Post-LN 和 Pre-LN

python
x = src

if self.norm_first:
    x = x + self._sa_block(
        self.norm1(x),
        src_mask,
        src_key_padding_mask,
        is_causal=is_causal,
    )
    x = x + self._ff_block(self.norm2(x))

else:
    x = self.norm1(
        x + self._sa_block(
            x,
            src_mask,
            src_key_padding_mask,
            is_causal=is_causal,
        )
    )
    x = self.norm2(x + self._ff_block(x))

return x

也就是说:

默认:norm_first=False

这是 Post-LN:

text
x = Norm(x + SelfAttention(x))
x = Norm(x + FFN(x))

如果:norm_first=True

这是 Pre-LN:

text
x = x + SelfAttention(Norm(x))
x = x + FFN(Norm(x))

即:

Post-LN:先残差相加,再 LayerNorm

Pre-LN:先 LayerNorm,再进子层,最后残差相加

EncoderLayer 里的 _sa_block

简化版:

python
def _sa_block(self, x, attn_mask, key_padding_mask, is_causal=False):
    x = self.self_attn(
        x, x, x,
        attn_mask=attn_mask,
        key_padding_mask=key_padding_mask,
        need_weights=False,
        is_causal=is_causal,
    )[0]

    return self.dropout1(x)

关键是:

python
self.self_attn(x, x, x)

也就是:

text
query = x
key   = x
value = x

所以 encoder layer 里的 attention 是 self-attention

EncoderLayer 里的 _ff_block

简化版:

python
def _ff_block(self, x):
    x = self.linear2(
        self.dropout(
            self.activation(
                self.linear1(x)
            )
        )
    )

    return self.dropout2(x)

也就是:

text
x

Linear(d_model -> dim_feedforward)

activation

Dropout

Linear(dim_feedforward -> d_model)

Dropout

TransformerDecoderLayer:Self-Attention + Cross-Attention + FFN

DecoderLayer 比 EncoderLayer 多一个 cross-attention。官方代码里主要包含:

python
self.self_attn = MultiheadAttention(d_model, nhead)
self.multihead_attn = MultiheadAttention(d_model, nhead)

self.linear1 = Linear(d_model, dim_feedforward)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model)

self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.norm3 = LayerNorm(d_model)

self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
self.dropout3 = Dropout(dropout)

结构:

text
tgt
 ├── masked self-attention over tgt
 ├── cross-attention over encoder memory
 ├── feed-forward network
 └── output

DecoderLayer 的 forward

简化版:

python
x = tgt

if self.norm_first:
    x = x + self._sa_block(
        self.norm1(x),
        tgt_mask,
        tgt_key_padding_mask,
        tgt_is_causal,
    )

    x = x + self._mha_block(
        self.norm2(x),
        memory,
        memory_mask,
        memory_key_padding_mask,
        memory_is_causal,
    )

    x = x + self._ff_block(self.norm3(x))

else:
    x = self.norm1(
        x + self._sa_block(
            x,
            tgt_mask,
            tgt_key_padding_mask,
            tgt_is_causal,
        )
    )

    x = self.norm2(
        x + self._mha_block(
            x,
            memory,
            memory_mask,
            memory_key_padding_mask,
            memory_is_causal,
        )
    )

    x = self.norm3(x + self._ff_block(x))

return x

默认 norm_first=False,所以是:

text
x = Norm(x + SelfAttention(x))
x = Norm(x + CrossAttention(x, memory))
x = Norm(x + FFN(x))

如果 norm_first=True,就是:

text
x = x + SelfAttention(Norm(x))
x = x + CrossAttention(Norm(x), memory)
x = x + FFN(Norm(x))

DecoderLayer 里的 _sa_block

decoder 的 self-attention 是:

python
def _sa_block(self, x, attn_mask, key_padding_mask, is_causal=False):
    x = self.self_attn(
        x, x, x,
        attn_mask=attn_mask,
        key_padding_mask=key_padding_mask,
        is_causal=is_causal,
        need_weights=False,
    )[0]

    return self.dropout1(x)

这里还是:

text
query = x
key   = x
value = x

但是 decoder 训练时通常要加 causal mask,避免当前位置看到未来 token。

也就是:

text
第 1 个 token 只能看第 1 个
第 2 个 token 可以看第 1~2 个
第 3 个 token 可以看第 1~3 个
...

DecoderLayer 里的 _mha_block:cross-attention

cross-attention 是 decoder layer 最关键的区别:

python
def _mha_block(self, x, mem, attn_mask, key_padding_mask, is_causal=False):
    x = self.multihead_attn(
        x, mem, mem,
        attn_mask=attn_mask,
        key_padding_mask=key_padding_mask,
        is_causal=is_causal,
        need_weights=False,
    )[0]

    return self.dropout2(x)

这里是:

text
query = x       # decoder 当前 hidden states
key   = mem     # encoder output
value = mem     # encoder output

所以 decoder 的每个位置会用自己的 hidden state 去查询 encoder 的输出。

这就是为什么 decoder forward 需要两个输入:

python
decoder(tgt, memory)

其中:

text
tgt    = decoder 当前输入
memory = encoder 对 src 的编码结果

causal mask:防止 decoder 看未来

python
nn.Transformer.generate_square_subsequent_mask(sz)

它会生成一个方阵 causal mask。

例如:

python
mask = nn.Transformer.generate_square_subsequent_mask(4)

大概是:

text
[[0,   -inf, -inf, -inf],
 [0,    0,   -inf, -inf],
 [0,    0,    0,   -inf],
 [0,    0,    0,    0  ]]

含义是:

text
第 0 个位置不能看 1、2、3
第 1 个位置不能看 2、3
第 2 个位置不能看 3
第 3 个位置可以看 0、1、2、3

一个完整的最小例子

python
import torch
import torch.nn as nn

N = 32   # batch size
S = 10   # src sequence length
T = 20   # tgt sequence length
E = 512  # d_model

transformer = nn.Transformer(
    d_model=E,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    batch_first=True,
)

src = torch.randn(N, S, E)
tgt = torch.randn(N, T, E)

tgt_mask = nn.Transformer.generate_square_subsequent_mask(
    T,
    device=tgt.device,
)

out = transformer(
    src,
    tgt,
    tgt_mask=tgt_mask,
)

print(out.shape)
# [N, T, E] 
# [32, 20, 512]

压缩版

python
class Transformer(nn.Module):
    def __init__(self):
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(...),
            num_encoder_layers,
            encoder_norm,
        )

        self.decoder = TransformerDecoder(
            TransformerDecoderLayer(...),
            num_decoder_layers,
            decoder_norm,
        )

    def forward(self, src, tgt):
        memory = self.encoder(src)
        output = self.decoder(tgt, memory)
        return output
python
class TransformerEncoderLayer(nn.Module):
    def forward(self, src):
        src = src + SelfAttention(src)
        src = src + FFN(src)
        return src
python
class TransformerDecoderLayer(nn.Module):
    def forward(self, tgt, memory):
        tgt = Norm(tgt + SelfAttention(tgt))
        tgt = Norm(tgt + CrossAttention(tgt, memory))
        tgt = Norm(tgt + FFN(tgt))
        return tgt