机构:斯坦福、Google Brain
作者:Kevin Clark、Minh-Thang Luong、Quoc V. Le
论文地址:https://arxiv.org/abs/2003.10555
收录会议:ICLR 2020
论文代码:https://github.com/google-research/electra
MLM(Masked language modeling )方式的预训练语言模型如BERT是在输入上用[MASK]遮蔽掉部分tokens,再训练一个模型以重建出原始的tokens。这种方式迁移到下游NLP任务时能够得到较好的结果,已然成为NLP任务中的标配。但是这种预训练的方法往往需要大量的算力。为此,本文提出一种更具样本效率的预训练任务——替换token检测(RTD)。RTD不对输入进行遮蔽,而是从生成网络中采样得到可信的tokens,再替换掉原始输入上的tokens。本文这里不再训练模型预测遭到破坏的输入中的tokens,而是训练一个判别器以辨别此时的输入中的每个token是否是生成器生成的。通过实验表明这种新的预训练任务比MLM更高效,这是由于该任务定义于全部的输入tokens,而非仅仅被遮蔽掉的那一部分小小的输入子集。实验结果表明本文方案所学习到的上下文表征大大优于相同模型大小、相同数据量和相同算力下的BERT。Small版的模型上收益更加显著,比如在GLUE自然语言理解基准上单块GPU上训练4天结果优于GPT(算力比其高30倍+)。本文的方法在规模上也是很有效,只使用RoBERTa和XLNet四分之一的算力即可得到可以比肩的结果,在使用相同算力的下可以超越对方。
本文方法需要训练两个神经网络分别是生成器 G G G和判别器 D D D。两者本质上是由encoder组成,比如Transformer。encoder将输入token序列 x = [ x 1 , … , x n ] \boldsymbol{x}=\left[x_{1}, \ldots, x_{n}\right] x=[x1,…,xn]映射为一序列上下文表征向量 h ( x ) = [ h 1 , … , h n ] h(\boldsymbol{x})=\left[h_{1}, \ldots, h_{n}\right] h(x)=[h1,…,hn]。给定位置 t t t,生成器用softmax层输出 x t x_t xt的概率分布:
p G ( x t ∣ x ) = exp ( e ( x t ) T h G ( x ) t ) / ∑ x ′ exp ( e ( x ′ ) T h G ( x ) t ) p_{G}\left(x_{t} | \boldsymbol{x}\right)=\exp \left(e\left(x_{t}\right)^{T} h_{G}(\boldsymbol{x})_{t}\right) / \sum_{x^{\prime}} \exp \left(e\left(x^{\prime}\right)^{T} h_{G}(\boldsymbol{x})_{t}\right) pG(xt∣x)=exp(e(xt)ThG(x)t)/x′∑exp(e(x′)ThG(x)t)
其中 e e e表示token的嵌入表征。给定位置 t t t,判别器预测token x t x_t xt是否是生成器生成的,其输出层是sigmoid:
D ( x , t ) = sigmoid ( w T h D ( x ) t ) D(\boldsymbol{x}, t)=\operatorname{sigmoid}\left(w^{T} h_{D}(\boldsymbol{x})_{t}\right) D(x,t)=sigmoid(wThD(x)t)
训练的生成器是为了执行MLM。给定输入列 x = [ x 1 , … , x n ] \boldsymbol{x}=\left[x_{1}, \ldots, x_{n}\right] x=[x1,…,xn],MLM先选择一些随机的位置进行遮蔽,遮蔽集合为 m = [ m 1 , … , m k ] \boldsymbol{m}=\left[m_{1}, \ldots, m_{k}\right] m=[m1,…,mk]。这些被选中的tokens被用[MASK] token替换,该操作定义为 x masked = REPLACE ( x , m , [ M A S K ] ) \boldsymbol{x}^{\text {masked }}=\operatorname{REPLACE}(\boldsymbol{x}, \boldsymbol{m},[\mathrm{MASK}]) xmasked =REPLACE(x,m,[MASK])。生成器学习预测出被遮蔽掉的原始token。至于训练的判别器则是为了能够鉴别出token是否是来自生成器。生成器生成的token创建了 x corrupt \boldsymbol{x}^\text {corrupt} xcorrupt,再训练判别器鉴别出 x corrupt \boldsymbol{x}^\text {corrupt} xcorrupt中哪些token是和原始输入 x x x相匹配的。
总结下,模型根据以下式子重建输入:
m i ∼ unif { 1 , n } for i = 1 to k x masked = REPLACE ( x , m , [ M A S K ] ) x ^ i ∼ p G ( x i ∣ x masked ) for i ∈ m x corrupt = REPLACE ( x , m , x ^ ) \begin{aligned} &m_{i} \sim \operatorname{unif}\{1, n\} \text { for } i=1 \text { to } k \quad \boldsymbol{x}^{\text {masked }}=\operatorname{REPLACE}(\boldsymbol{x}, \boldsymbol{m},[\mathrm{MASK}])\\ &\hat{x}_{i} \sim p_{G}\left(x_{i} | \boldsymbol{x}^{\text {masked }}\right) \text { for } i \in \boldsymbol{m} \quad \boldsymbol{x}^{\text {corrupt }}=\operatorname{REPLACE}(\boldsymbol{x}, \boldsymbol{m}, \hat{\boldsymbol{x}}) \end{aligned} mi∼unif{1,n} for i=1 to kxmasked =REPLACE(x,m,[MASK])x^i∼pG(xi∣xmasked ) for i∈mxcorrupt =REPLACE(x,m,x^)
对应的损失函数分别如下:
MLM损失函数:
L M L M ( x , θ G ) = E ( ∑ i ∈ m − log p G ( x i ∣ x m a k k e d ) ) \mathcal{L}_{\mathrm{MLM}}\left(\boldsymbol{x}, \theta_{G}\right)=\mathbb{E}\left(\sum_{i \in m}-\log p_{G}\left(x_{i} | \boldsymbol{x}^{\mathrm{makked}}\right)\right) LMLM(x,θG)=E(i∈m∑−logpG(xi∣xmakked))
判别器损失函数:
L D i s c ( x , θ D ) = E ( ∑ t = 1 n − 1 ( x t corrupt = x t ) log D ( x corrupt , t ) − 1 ( x t cortupt ≠ x t ) log ( 1 − D ( x cornurt , t ) ) ) \mathcal{L}_{\mathrm{Disc}}\left(\boldsymbol{x}, \theta_{D}\right)=\mathbb{E}\left(\sum_{t=1}^{n}-\mathbb{1}\left(x_{t}^{\text {corrupt }}=x_{t}\right) \log D\left(\boldsymbol{x}^{\text {corrupt }}, t\right)-\mathbb{1}\left(x_{t}^{\text {cortupt }} \neq x_{t}\right) \log \left(1-D\left(\boldsymbol{x}^{\text {cornurt }}, t\right)\right)\right) LDisc(x,θD)=E(t=1∑n−1(xtcorrupt =xt)logD(xcorrupt ,t)−1(xtcortupt =xt)log(1−D(xcornurt ,t)))
尽管看似GAN的训练目标函数,但是存在以下几点关键不同点:
第一,当生成器碰巧生成正确的token,则该token被视为真实的,而非假冒的。实验发现这种方式可以一定程度上地改善下游任务。
第二,生成器的训练是基于最大似然而非对抗训练。后者企图在训练过程中骗过判别器。对抗训练在这里是很难实现,究其原因在于无法通过从生成器采样来进行反向传播。尽管本文也尝试使用强化学习的方式训练生成器从而绕开这个问题(附件F),但是效果仍不及最大似然的训练方式。
第三,不像GAN那样在生成器的输入中添加噪声
最终的联合损失:
min θ G , θ D ∑ x ∈ X L M L M ( x , θ G ) + λ L D i s c ( x , θ D ) \min _{\theta_{G}, \theta_{D}} \sum_{\boldsymbol{x} \in \mathcal{X}} \mathcal{L}_{\mathrm{MLM}}\left(\boldsymbol{x}, \theta_{G}\right)+\lambda \mathcal{L}_{\mathrm{Disc}}\left(\boldsymbol{x}, \theta_{D}\right) θG,θDminx∈X∑LMLM(x,θG)+λLDisc(x,θD)
其中 X \mathcal{X} X表示原始语料。实验过程中用单样本来近似损失的期望。同时不把判别器的损失反向传播到生成器,PS:由于采样的步骤也做不到。在预训练之后,丢弃生成器,并在下游任务中对判别器进行微调。
为验证文本方法的有效性,在GLUE和SQuAD上进行实验。预训练数据集分两种情况:第一种是和BERT对比,则是采用Wikipedia 和 BooksCorpus数据集,合计大概3.3 Billion tokens;第二种是和XLNet对比,则与其相同的数据集进行预训练,即ClueWeb、CommonCrawl和Gigaword,合计33 Billion tokens。
本文的模型架构和超参数总体上和BERT相同。在微调阶段,对于GLUE则在ELECTRA的顶部增加一个简单的线性分类器;对于SQuAD则借用XLNet中的问-答模块,将其加到ELECTRA的顶部。之所以这部分借用XLNet,是由于该模块比BERT更精细,起始位置和终点位置是联合预测的,而不像BERT那样两者各自独立,另外对于SQuAD2.0该模块还有能否回答的分类器。需要注意的是,一些评估数据集很小,这意味着微调模型的准确性可能会随着随机种子的不同而发生很大的变化。因此,最终报告结果是从相同的预先训练的检查点运行10次微调的中位数。
为了改善模型本文在模型中进行了以下几点拓展。除非特别说明,否则实验对比中模型大小和训练数据集均与BERT-Base相同。
第一,权重共享。
生成器和判别器权重参数共享可以提高预训练的效率。当生成器和判别器大小相同,则二者的所有transformer权重都可以密切相连。但是,实验中发现较小的生成器更为高效,这使得只能共享生成器和判别器的嵌入参数(包括token嵌入和位置嵌入)。这种情况下,使用的嵌入大小是根据判别器隐状态的嵌入尺寸。在生成器中添加一个线性变换层以将嵌入投射为生成器隐状态尺寸的表征。
实验对比了,生成器和判别器相同大小的情况。在GLUE中,权重完全无关时得分为83.6,token 嵌入共享时得分为84.3,所有权重共享时得分为84.4。可以猜想本文的模型得益于token嵌入的共享方案,可能的原因是MLM在学习以下表征上尤为有效:
基于上述这些发现,最终选用嵌入参数共享。
第二,较小的生成器。
如果生成器和判别器大小相同,则训练ELECTRA每步需要的计算量大约是只使用MLM训练的两倍。为此,本文使用较小的生成器以降低该因素的影响。具体是降低层的大小,同时保持其他超参数恒定不变。此外,本文还尝试了一个极简的unigram生成器,该生成器根据训练语料中的频率采样生成假的tokens。不同大小的生成器和判别器在GLUE上的得分如Figure 3中左图所示。
从中发现生成器最佳的尺寸是判别器尺寸的1/4-1/2。这可能是如果生成器过于强大,那么对应的任务对判别器来说越难。
第三,训练方法。
本文还进一步探索了ELECTRA的其他训练方法。尽管这些探索性工作并没有最终改善结果。探索的训练方法使用以下两阶段的训练流程:
上述提到的权重初始化要求生成器和判别器的大小相同。实验中发现,如果判别器没有这个权重初始化操作,那么有时会甚至在大多数类之外学习失败。这可能是由于此时的生成器已远远领先于判别器。联合训练的方式另一方面天然地为判别器提供了一个总纲,在这个总纲中,生成器一开始很弱,但在整个训练过程中变得越来越好。此外本文还探索了如同GAN那般使用对抗方式训练生成器,具体使用强化学习来适应离散操作的采样生成器。具体参考附件F。
各种训练方法的对比如Figure 3右图所示。从中可以看出,对于两阶段的训练方法,当将生成目标转为判别目标后,下游任务的性能得到显著提升,但是最终没有超过联合训练方案。至于对抗训练方案,尽管超过了BERT,但是仍然逊色于最大似然训练方法。这之间的差距可能源于以下2个因素:
第一,对抗方式训练生成器在MLM上表现更差,只能取得58%的准确率,而最大似然的训练方案可以取得65%的准确率。准确率较差的主要原因是在生成文本的大动作空间中,强化学习的样本效率较低。
第二,对抗方式训练的生成器产生一个低熵输出分布,其中大部分概率集中在单个token上,这意味着生成器生成的样本中没有太多的多样性。
在GLUE dev数据集上各个模型Small版之间的对比详情如Table 1 所示。
可以看出,在模型大小相同时ELECTRA-Small显著优于其他模型。比如ELECTRA-Small高出BERT-Small大概5分,甚至高于更大的GPT模型。ELECTRA-Small的训练多数是收敛的,即使模型的训练时间更短(只有6个小时),仍然可以获得合理的性能结果。
实验结果还表明ELECTRA在中度大小的模型上依然有效,ELECTRA-Base显著优于BERT-Base,甚至超越BERT-Large(GLUE得分为84)。
训练ELECTRA-Large以验证RTD预训练任务在大规模先进的Transformers模型中的有效性。在GLUE dev数据集上的实验结果如Table 2所示。
其中ELECTRA-400K表示400k的训练step。可以看出ELECTRA-400K能够取得与RoBERTa、XLNet相比肩的性能结果,但所需的算力仅仅是RoBERTa和XLNet的1/4。这表明ELECTRA的采样效率增益在大模型上保持不变。训练得越久的ELECTRA,比如ELECTRA-1.75M在大部分的GLUE任务上都优于其他模型且所需的预训练算力更少。令人惊讶的是,基线BERT模型得分明显低于RoBERTa- 100k,这表明我们的模型可能受益于更多的超参数调优或RoBERTa训练数据。ELECTRA在GLUE test数据集上仍能维持收益,具体见Table 3。
实验结果表明在相同算力资源下,ELECTRA的结果得分高于基于MLM的模型。比如ELECTRA-400K优于RoBERTa-100k和BERT基线模型(均使用相近的预训练算力)。ELECTRA训练得越久,结果能够进一步提升:在SQuAD数据集上ELECTRA-1.75M超出此前的其他模型。ELECTRA-Base不仅大大优于相同模型大小的BERT-Base 和XLNet-Base,甚至在多数指标上超越BERT-Large。ELECTRA在SQuAD 2.0上的表现优于SQuAD 1.1。这可能得益于RTD任务的引入,在RTD任务中模型能够辨别token的真伪,这对于有是否可回答类型问题的SQuAD 2.0数据集来说更具迁移能力。
为了进一步理解ELECTRA模型收益来源,对比了以下几种预训练目标:
ELECTRA 15%:
这与ELECTRA大致相同,仅有的不同的在判别器的损失仅来自于输入中被遮蔽的15%tokens。换句话说, L D i s c \mathcal{L}_{\mathrm{Disc}} LDisc中 i i i的不再是1-n,而是 i ∈ m i \in m i∈m。
Replace MLM:
该目标与MLM大致相同,所唯一不同的是不使用[MASK]替换tokens,而是用生成器生成的token替换。这一目标测试了ELECTRA的收益在多大程度上来自于预训练期间解决的将模型暴露于[MASK] tokens的不一致问题,而不是微调阶段。
All-TokensMLM:
与Replace MLM类似,该目标的替换token也是来自生成器。所不同的是,模型预测输入中的所有tokens,而不再仅仅是被遮蔽掉的那些。实验发现当模型有显性的拷贝机制时,效果更好。这种拷贝机制采用一个sigmoid层控制每个输入token的输出拷贝概率为 D D D。模型的输出分布由输入token的D权值,再加上1-D乘以MLM softmax的输出。该模型本质上是BERT和ELECTRA的结合体。需要注意的是,如果没有生成器生成替换字符,模型将很容易地学会从[MASK]标记的词汇表中进行预测,并从输入中直接拷贝其他token。
从实验结果可以看出:
第一,ELECTRA很大程度上受益于基于所有输入token定义的损失。采用15%子集定义的损失逊色于ELECTRA。
第二,从Replace MLM略胜BERT,可以看出BERT预训练时引入的[MASK] token导致微调的不一致,确实略微有损性能指标。
第三,使用All-Tokens MLM方案的模型可以很大程度上缩减BERT和ELECTRA之间的差距。
总之:ELECTRA模型的提升大部分可以归因于从所有tokens中学习,而小部分源于提升得益于减轻预训练-微调的不一致。
另外还对比了BERT和ELECTRA在不同模型尺寸下的表现,如Figure 4所示:
可以看出,小模型中ELECTRA的收益更大,对于完全训练的小模型,ELECTRA也是显著优于BERT。这可以看出ELECTRA训练参数更为高效,估计是由于ELECTRA不必对每个位置可能的token的完整分布进行建模。但是仍然需要进行更多的分析来完全解释ELECTRA的参数效率问题。
本文提出一个新的自监督任务:替换Token检测(RTD)。其核心是训练一个文本encoder以在输入tokens中鉴别出来自于小型生成器所生成的高质量负样本。对比于MLM,本文的预训练目标计算更高效,在下游任务中能取得更优的结果。本文的方法即使在算力相对较少的情况下,仍然有效,这对于算力贫穷人士来说真是大大的福音。