Transformer
RNN 的主要问题:
- 梯度消失/爆炸:长距离依赖难以学习
- 顺序计算:无法并行处理序列
- 信息瓶颈:最后时刻隐藏状态需承载全部信息
Transformer 的改进:
- 并行计算:同时处理整个序列
- 自注意力机制:直接建立任意位置间的联系
- 位置编码:显式注入位置信息
结构
图:Transformer 单元
图:Transformer 的详细结构
- 输入
- 编码器输入
- 解码器输入
- 输出
- 线性层
- Softmax 层
- 编码器
- 由 N 个编码器层堆叠而成
- 每个编码器层由两个子层连接结构组成
- 第一个子层连接结构包括一个多头自注意力子层和规范化层以及一个残差连接
- 第二个子层连接结构包括一个前馈全连接子层和规范化层以及一个残差连接
- 解码器
- 由 N 个解码器层堆叠而成
- 每个解码器层由三个子层连接结构组成
- 第一个子层连接结构包括一个带掩码的-多头自注意力子层和规范化层以及一个残差连接
- 第二个子层连接结构包括一个多头注意力子层(编码器到解码器)和规范化层以及一个残差连接
- 第三个子层连接结构包括一个前馈全连接子层和规范化层以及一个残差连接
核心组件
自注意力机制(Self-Attention)
在传统的神经网络处理序列时,模型只能一步步按顺序处理,难以捕捉长距离依赖关系。自注意力机制就是为了让序列中的每个元素都能直接与序列中所有其他元素进行交互,无论它们直接的距离多远。
定义:
| 符号 | 维度 | 含义 |
|---|---|---|
| 输入矩阵(n=序列长度,d=特征维度) | ||
| Query 矩阵(查询向量) | ||
| Key 矩阵(键向量) | ||
| Value 矩阵(值向量) | ||
| 可学习参数矩阵 |
TIP
表示当前需要关注的信息或问题,用于确定输入序列中哪些部分与当前任务相关 用于匹配查询,通过计算相似度判断输入序列中哪些元素与查询匹配 存储实际信息
推导过程
将输入转换为 Query、Key、Value:
计算注意力分数:
生成注意力权重矩阵:
得到最终注意力输出:
带掩码自注意力层(Masked Multi-head attention)
编码时,对于
解码时,对于
多头注意力(Multi-Head Attention)
其中:
位置编码(Positional Encoding)
与 RNN 和 LSTM 等顺序算法不同,Transformer 没有内置机制来捕获句子中单词的相对位置,所以在 Transformer 的 encoder 和 decoder 的输入层中,使用了 Positional Encoding,使得最终的输入满足:
原始正弦编码公式:
前馈网络(Feed Forward Network)
包括两个线性变换+ReLU 激活:
计算复杂度
当输入批次大小为
Self-Attention 层
- 计算
、 、
输入输出
计算量为:
TIP
矩阵加法运算考虑偏差 bias计算量就是
- 计算
输入输出
- Softmax 与加权求和
Softmax 计算量较小,通常忽略。
输入输出
- 输出投影
线性变换将结果映射回
输入输出
MLP 层
- 线性层(扩展层)
输入输出
- 线性层(压缩层)
输入输出
logits
Logits 层是将最终的 Transformer 隐藏层输出(维度
输入输出
总的计算复杂度
空间复杂度
大模型在训练过程中通常采用混合精度训练,中间激活值一般是 float16 或者 bfloat16 数据类型的。在分析中间激活的显存占用时,假设中间激活值是以 float16 或 bfloat16 数据格式来保存的,每个元素占了 2 个 bytes,dropout 操作的 mask 矩阵,每个元素只占 1 个 bytes。需要保存的中间激活占用显存大小计算如下:
Self-Attention 层
、 、 共享一个输入 ,则显存占用为 - 对于
,两个张量形状都是 ,显存占用为 - 对于
,函数输入 形状为 ,显存占用为 - 计算完
,会进行 dropout,需要保存一个 mask 矩阵,其形状与 相同,显存占用为 - 计算
,二者占用显存大小为 - 计算输出映射和一个 dropout 操作,二者占用显存大小为
综上,Self-Attention 层的显存占用为
MLP 层
- 第一个线性层的输入占用显存
- 激活函数的输入占用显存
- 第二个线性层的输入占用显存
- 最后的 dropout 操作需要保存的 mask 矩阵占用显存
综上,MLP 层的显存占用为
LN
Self-Attention 层和 MLP 层分别对应了一个 LN,其输入占用显存为
总的空间复杂度
问题
Transformer 的计算复杂度为:
可以注意到,
架构
NLP/LLM 最经典的三类 Transformer 架构是 encoder-only、decoder-only 和 encoder-decoder。
- encoder-only:双向看上下文,擅长分类、匹配、抽取和表示学习。
- decoder-only:只能看左侧历史,天然适合 next-token prediction 和开放生成。
- encoder-decoder:先编码输入,再在 cross-attention 条件下生成输出,适合翻译、摘要、问答。
causal mask、bidirectional mask、cross-attention都是在描述这三类结构的差异。- 现代 LLM 主流是 decoder-only,但检索、reranker 和很多 NLP 任务仍常用 encoder。
代码
nn.Transformer
只负责主体结构
import torch
import torch.nn as nn
transformer = nn.Transformer(
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
batch_first=False,
)d_model 每个 token 的特征维度
nhead 多头注意力的 head 数量
num_encoder_layers encoder 堆叠多少层
num_decoder_layers decoder 堆叠多少层
dim_feedforward FFN 中间层维度
dropout dropout 概率
activation FFN 中间激活函数,默认 relu,也支持 gelu
batch_first 是否使用 [batch, seq, feature]
norm_first LayerNorm 放在 attention/FFN 之前还是之后
bias Linear 和 LayerNorm 是否使用 bias__init__:先创建 Encoder,再创建 Decoder
class Transformer(nn.Module):
def __init__(
self,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
activation=F.relu,
custom_encoder=None,
custom_decoder=None,
layer_norm_eps=1e-5,
batch_first=False,
norm_first=False,
bias=True,
device=None,
dtype=None,
):
super().__init__()
if custom_encoder is not None:
self.encoder = custom_encoder
else:
encoder_layer = TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
batch_first=batch_first,
norm_first=norm_first,
bias=bias,
)
encoder_norm = LayerNorm(
d_model,
eps=layer_norm_eps,
bias=bias,
)
self.encoder = TransformerEncoder(
encoder_layer,
num_encoder_layers,
encoder_norm,
)
if custom_decoder is not None:
self.decoder = custom_decoder
else:
decoder_layer = TransformerDecoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
batch_first=batch_first,
norm_first=norm_first,
bias=bias,
)
decoder_norm = LayerNorm(
d_model,
eps=layer_norm_eps,
bias=bias,
)
self.decoder = TransformerDecoder(
decoder_layer,
num_decoder_layers,
decoder_norm,
)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
self.batch_first = batch_first_reset_parameters():初始化所有矩阵参数
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
xavier_uniform_(p)含义是:
二维或更高维参数:重新 Xavier uniform 初始化
一维参数:跳过,比如 bias、LayerNorm weight/bias比如:
nn.Linear(512, 2048)它的权重矩阵 shape 是 [2048, 512]
所以:
fan_in = 512
fan_out = 2048Xavier uniform 的范围是:
其中:
也就是:
所以权重大概从 [-0.0484, 0.0484] 里采样。
forward():先 Encoder,再 Decoder
nn.Transformer.forward() 简化版:
def forward(
self,
src,
tgt,
src_mask=None,
tgt_mask=None,
memory_mask=None,
src_key_padding_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None,
src_is_causal=None,
tgt_is_causal=None,
memory_is_causal=False,
):
is_batched = src.dim() == 3
if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
raise RuntimeError("the batch number of src and tgt must be equal")
elif self.batch_first and src.size(0) != tgt.size(0) and is_batched:
raise RuntimeError("the batch number of src and tgt must be equal")
if src.size(-1) != self.d_model or tgt.size(-1) != self.d_model:
raise RuntimeError("the feature number of src and tgt must be equal to d_model")
memory = self.encoder(
src,
mask=src_mask,
src_key_padding_mask=src_key_padding_mask,
is_causal=src_is_causal,
)
output = self.decoder(
tgt,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
tgt_is_causal=tgt_is_causal,
memory_is_causal=memory_is_causal,
)
return outputShape 规则
设:
S = source sequence length
T = target sequence length
N = batch size
E = feature dimension = d_model当:
batch_first=False输入输出是:
src: [S, N, E]
tgt: [T, N, E]
memory: [S, N, E]
output: [T, N, E]当:
batch_first=True输入输出是:
src: [N, S, E]
tgt: [N, T, E]
memory: [N, S, E]
output: [N, T, E]TransformerEncoder:复制 N 个 EncoderLayer
简化版:
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(self, src, mask=None, src_key_padding_mask=None, is_causal=None):
output = src
for mod in self.layers:
output = mod(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
is_causal=is_causal,
)
if self.norm is not None:
output = self.norm(output)
return output所以 encoder 的逻辑非常直接:
output = src
for layer in encoder_layers:
output = layer(output)
output = final_norm(output)如果 num_encoder_layers=6,大概就是:
src
↓
EncoderLayer 1
↓
EncoderLayer 2
↓
EncoderLayer 3
↓
EncoderLayer 4
↓
EncoderLayer 5
↓
EncoderLayer 6
↓
encoder_norm
↓
memoryTransformerDecoder:复制 N 个 DecoderLayer
decoder 的结构和 encoder 很像,只是每一层额外需要 memory:
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(
self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None,
tgt_is_causal=None,
memory_is_causal=False,
):
output = tgt
for mod in self.layers:
output = mod(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
tgt_is_causal=tgt_is_causal,
memory_is_causal=memory_is_causal,
)
if self.norm is not None:
output = self.norm(output)
return outputTransformerEncoderLayer:Self-Attention + FFN
EncoderLayer 是 Transformer encoder 的基本单元。官方代码里,它主要包含这些模块:
self.self_attn = MultiheadAttention(d_model, nhead)
self.linear1 = Linear(d_model, dim_feedforward)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model)
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)x
├── Self-Attention
├── Add & Norm
├── Feed Forward Network
└── Add & NormFFN 部分是:
linear2(dropout(activation(linear1(x))))也就是:
d_model
↓ Linear
dim_feedforward
↓ activation
dim_feedforward
↓ Dropout
dim_feedforward
↓ Linear
d_model例如默认参数:
d_model = 512
dim_feedforward = 2048则 FFN 是:
512 -> 2048 -> 512EncoderLayer 的 forward:Post-LN 和 Pre-LN
x = src
if self.norm_first:
x = x + self._sa_block(
self.norm1(x),
src_mask,
src_key_padding_mask,
is_causal=is_causal,
)
x = x + self._ff_block(self.norm2(x))
else:
x = self.norm1(
x + self._sa_block(
x,
src_mask,
src_key_padding_mask,
is_causal=is_causal,
)
)
x = self.norm2(x + self._ff_block(x))
return x也就是说:
默认:norm_first=False
这是 Post-LN:
x = Norm(x + SelfAttention(x))
x = Norm(x + FFN(x))如果:norm_first=True
这是 Pre-LN:
x = x + SelfAttention(Norm(x))
x = x + FFN(Norm(x))即:
Post-LN:先残差相加,再 LayerNorm
Pre-LN:先 LayerNorm,再进子层,最后残差相加
EncoderLayer 里的 _sa_block
简化版:
def _sa_block(self, x, attn_mask, key_padding_mask, is_causal=False):
x = self.self_attn(
x, x, x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
is_causal=is_causal,
)[0]
return self.dropout1(x)关键是:
self.self_attn(x, x, x)也就是:
query = x
key = x
value = x所以 encoder layer 里的 attention 是 self-attention。
EncoderLayer 里的 _ff_block
简化版:
def _ff_block(self, x):
x = self.linear2(
self.dropout(
self.activation(
self.linear1(x)
)
)
)
return self.dropout2(x)也就是:
x
↓
Linear(d_model -> dim_feedforward)
↓
activation
↓
Dropout
↓
Linear(dim_feedforward -> d_model)
↓
DropoutTransformerDecoderLayer:Self-Attention + Cross-Attention + FFN
DecoderLayer 比 EncoderLayer 多一个 cross-attention。官方代码里主要包含:
self.self_attn = MultiheadAttention(d_model, nhead)
self.multihead_attn = MultiheadAttention(d_model, nhead)
self.linear1 = Linear(d_model, dim_feedforward)
self.dropout = Dropout(dropout)
self.linear2 = Linear(dim_feedforward, d_model)
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.norm3 = LayerNorm(d_model)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
self.dropout3 = Dropout(dropout)结构:
tgt
├── masked self-attention over tgt
├── cross-attention over encoder memory
├── feed-forward network
└── outputDecoderLayer 的 forward
简化版:
x = tgt
if self.norm_first:
x = x + self._sa_block(
self.norm1(x),
tgt_mask,
tgt_key_padding_mask,
tgt_is_causal,
)
x = x + self._mha_block(
self.norm2(x),
memory,
memory_mask,
memory_key_padding_mask,
memory_is_causal,
)
x = x + self._ff_block(self.norm3(x))
else:
x = self.norm1(
x + self._sa_block(
x,
tgt_mask,
tgt_key_padding_mask,
tgt_is_causal,
)
)
x = self.norm2(
x + self._mha_block(
x,
memory,
memory_mask,
memory_key_padding_mask,
memory_is_causal,
)
)
x = self.norm3(x + self._ff_block(x))
return x默认 norm_first=False,所以是:
x = Norm(x + SelfAttention(x))
x = Norm(x + CrossAttention(x, memory))
x = Norm(x + FFN(x))如果 norm_first=True,就是:
x = x + SelfAttention(Norm(x))
x = x + CrossAttention(Norm(x), memory)
x = x + FFN(Norm(x))DecoderLayer 里的 _sa_block
decoder 的 self-attention 是:
def _sa_block(self, x, attn_mask, key_padding_mask, is_causal=False):
x = self.self_attn(
x, x, x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
is_causal=is_causal,
need_weights=False,
)[0]
return self.dropout1(x)这里还是:
query = x
key = x
value = x但是 decoder 训练时通常要加 causal mask,避免当前位置看到未来 token。
也就是:
第 1 个 token 只能看第 1 个
第 2 个 token 可以看第 1~2 个
第 3 个 token 可以看第 1~3 个
...DecoderLayer 里的 _mha_block:cross-attention
cross-attention 是 decoder layer 最关键的区别:
def _mha_block(self, x, mem, attn_mask, key_padding_mask, is_causal=False):
x = self.multihead_attn(
x, mem, mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
is_causal=is_causal,
need_weights=False,
)[0]
return self.dropout2(x)这里是:
query = x # decoder 当前 hidden states
key = mem # encoder output
value = mem # encoder output所以 decoder 的每个位置会用自己的 hidden state 去查询 encoder 的输出。
这就是为什么 decoder forward 需要两个输入:
decoder(tgt, memory)其中:
tgt = decoder 当前输入
memory = encoder 对 src 的编码结果causal mask:防止 decoder 看未来
nn.Transformer.generate_square_subsequent_mask(sz)它会生成一个方阵 causal mask。
例如:
mask = nn.Transformer.generate_square_subsequent_mask(4)大概是:
[[0, -inf, -inf, -inf],
[0, 0, -inf, -inf],
[0, 0, 0, -inf],
[0, 0, 0, 0 ]]含义是:
第 0 个位置不能看 1、2、3
第 1 个位置不能看 2、3
第 2 个位置不能看 3
第 3 个位置可以看 0、1、2、3一个完整的最小例子
import torch
import torch.nn as nn
N = 32 # batch size
S = 10 # src sequence length
T = 20 # tgt sequence length
E = 512 # d_model
transformer = nn.Transformer(
d_model=E,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
batch_first=True,
)
src = torch.randn(N, S, E)
tgt = torch.randn(N, T, E)
tgt_mask = nn.Transformer.generate_square_subsequent_mask(
T,
device=tgt.device,
)
out = transformer(
src,
tgt,
tgt_mask=tgt_mask,
)
print(out.shape)
# [N, T, E]
# [32, 20, 512]压缩版
class Transformer(nn.Module):
def __init__(self):
self.encoder = TransformerEncoder(
TransformerEncoderLayer(...),
num_encoder_layers,
encoder_norm,
)
self.decoder = TransformerDecoder(
TransformerDecoderLayer(...),
num_decoder_layers,
decoder_norm,
)
def forward(self, src, tgt):
memory = self.encoder(src)
output = self.decoder(tgt, memory)
return outputclass TransformerEncoderLayer(nn.Module):
def forward(self, src):
src = src + SelfAttention(src)
src = src + FFN(src)
return srcclass TransformerDecoderLayer(nn.Module):
def forward(self, tgt, memory):
tgt = Norm(tgt + SelfAttention(tgt))
tgt = Norm(tgt + CrossAttention(tgt, memory))
tgt = Norm(tgt + FFN(tgt))
return tgt