Skip to content

MLM

屏蔽语言模型

在屏蔽语言建模中,我们通常屏蔽给定句子中特定百分比的单词,模型期望基于该句子中的其他单词预测这些被屏蔽的单词。这样的训练方案使这个模型在本质上是双向的,因为掩蔽词的表示是根据出现的词来学习的,不管是左还是右。

MLM 的想法和 CLM 很不一样。它不是让模型“往后接”,而是让模型“把缺的词补出来”。

BERT 对 MLM 的定义很直接:随机把输入里一部分 token mask 掉,然后让模型根据上下文预测这些被 mask 的原始 token。论文还特别强调,这样做的结果是,模型在所有层都能同时融合左、右上下文,从而得到深度双向表示。

举个最简单的例子。原句是:

我 今天 去 北京

如果把第三个位置遮掉,就变成:

我 今天 [MASK] 北京

模型要预测:

被遮掉的那个位置原来是“去”

这时候它不是只看左边“我 今天”,也会看右边“北京”。所以 MLM 学到的是:

给定上下文,恢复被挖掉的词。

这就是为什么大家常把 MLM 叫作“完形填空”目标。

公式

还是设原始序列是

x=(x1,x2,,xT).

先随机采样一个被遮盖的位置集合 M1,,T,再把原句变成一个“损坏后的输入” x~。这个损坏后的输入,不一定所有被选中的位置都真替换成 [MASK]

那么 MLM 的训练目标可以写成:

LMLM(x)=EMtMlogpθ(xtx~).

这条式子的意思很简单:

  • 先随机选一些位置出来
  • 把这些位置做遮盖/扰动
  • 只在这些被选中的位置上计算预测损失
  • 要预测的是“原来真实的 token 是什么”

这里用 x~ 而不是 xt,是因为实际训练里通常会同时遮掉多个位置,所以某个位置在预测时,周围的一些 token 也可能已经被遮掉了。这个写法更准确。BERT 论文的原话是:mask token 对应的最终隐向量会送进一个词表 softmax,来预测原始词 id。

如果把 softmax 也写出来,形式和 CLM 很像。对于一个被选中的位置 tM,若模型输出 logits zt,那么

pθ(xt=vx~)=exp(zt,v)uVexp(zt,u),

对应位置的 loss 是

tMLM=logpθ(xtx~)logexp(zt,xt)uVexp(zt,u).

整句的 MLM loss 就是把所有被 mask 的位置加起来,或者再除以 |M| 做平均。 所以从形式上说,MLM 也是 cross-entropy;但和 CLM 的区别是:

  • CLM 在几乎每个位置都预测
  • MLM 只在被选中的那些位置上预测

BERT 里的 15% / 80%-10%-10%

BERT 论文里说得很清楚:它先随机选出 15% 的 token 位置作为预测目标;然后对这些被选中的位置,不是永远都换成 [MASK],而是采用下面的混合策略:

  • 80% 的时候,真的替换成 [MASK]
  • 10% 的时候,替换成一个随机 token
  • 10% 的时候,保持原 token 不变

然后这些被选中的位置,都要用交叉熵去预测原始 token。论文给出的原因也很直接:如果总是用 [MASK],那就会产生一个 pre-train / fine-tune mismatch,因为 [MASK] 这个符号在下游微调时通常不会出现。

这套 80/10/10 的直觉可以这样理解:

如果 100% 都替成 [MASK],模型会过于依赖“这里有个特殊符号,说明我要猜词了”; 加入 10% 随机替换,是让模型不能完全信任表面输入; 再加入 10% 保持不变,是让模型即使看到原词,也得学会利用上下文,而不是只把“被换成 [MASK] 的位置”当特例。

还有一个很重要的后果是:在 vanilla BERT 的 MLM 里,每条序列只有大约 15% 的位置直接参与预测损失,所以它的监督信号比 CLM 稀疏。BERT 自己也在附录里提到,MLM 的收敛会比 LTR LM 稍慢一些,但绝对效果很快超过了 LTR。

8. 为什么 MLM 更适合理解,不太适合原生自由生成

这部分是 CLM 和 MLM 最本质的差别。

CLM 学的是:

p(x)=tp(xtx<t),

所以它直接给出了一个规范的联合概率分解。这意味着它既能打分整句,也能自然地从左到右生成。

MLM 学的则是:

p(xtx~)

也就是“被挖空位置在上下文下的条件概率”。这对“理解上下文、补全缺词”很有帮助,但它并没有直接给出一个标准的左到右联合分布。所以,原生 MLM 并不像 CLM 那样天然适合一步一步自由生成整句。这个差别并不是实现细节,而是目标函数本身决定的。BERT 论文把 MLM 明确地和 left-to-right LM 区分开;而关于如何用 MLM 给句子打分,后续工作通常会用 pseudo-log-likelihood 这样的替代量。

最常见的 MLM 句子打分方式叫 PLL(pseudo-log-likelihood)

PLL(x)=t=1Tlogpθ(xtxt),

这里 xt 表示“把第 t 个词遮掉,其余词保留”。直觉上就是:把句子里的每个 token 轮流挖掉一次,再让 MLM 去猜它,然后把所有位置的 log-prob 加起来。对应的 pseudo-perplexity

PPPL(x)=exp(1TPLL(x)).

Salazar 等人的论文专门把 conventional LM 的 chain-rule log probability 和 MLM 的 PLL 区分开来:前者是标准联合概率的对数,后者是 MLM 的替代性评分办法。

所以一句话讲清楚就是:

CLM 学的是“怎么把句子接下去”;MLM 学的是“怎么看懂句子里缺了什么”。