前言
目前在NLP领域中,处理语言建模问题有两种最先进的架构:RNN和Transformer。RNN按照序列顺序逐个学习输入的单词或字符之间的关系,而Transformer则接收一整段序列,然后使用self-attention机制来学习它们之间的依赖关系。这两种架构目前来看都取得了令人瞩目的成就,但它们都局限在捕捉长期依赖性上。
为了解决这一问题,CMU联合Google Brain在2019年1月推出的一篇新论文《Transformer-XL:Attentive Language Models beyond a Fixed-Length Context》同时结合了RNN序列建模和Transformer自注意力机制的优点,在输入数据的每个段上使用Transformer的注意力模块,并使用循环机制来学习连续段之间的依赖关系。Transformer-XL在多种语言建模数据集(如单词级别的enwik8和字符级别的text8)上实现了目前的SoTA效果,且该模型在推理阶段速度更快,比之前最先进的利用Transformer进行语言建模的方法快300~1800倍。 同时,该论文也放出了其配套源码(包括TensorFlow和PyTorch的)、预训练模型及在各个数据集上训练的超参数,可以说是非常良心了~造福我等伸手党!
本文将主要针对模型原理及其PyTorch实现进行逐一对照解读,因笔者能力有限,如有不详尽之处,可移步文末的传送门进行详细阅读,并欢迎指出~
文章目录
- 前言
- 一. 回顾Transformer
- 二. vanilla Transformer
- 三. Transformer-XL
- 1. 引入循环机制
- 2. 相对位置编码
- 3. 整体计算公式
- 四. PyTorch实现
- 五. 实验结果
- 1. 语言建模指标
- 2. 两个创新点的优势
- 3. 测试阶段的速度
- 六. 总结
- 1. 模型特点
- 2. 优点
- 3. 不足
- 传送门
一. 回顾Transformer
在NLP领域中,一种对语言建模的最常用模型就是RNN,它可以捕捉单词之间的依赖关系。但因为梯度消失和爆炸的问题,RNN变得非常难以训练,LSTM单元和梯度裁剪方法的提出也不足以解决此类问题。同时RNN网络的计算速度往往很慢,其学习长期依赖的能力也较为有限(论文中提到,LSTM语言模型平均只能建模200个上下文词语)。
2017年6月,Google Brain在论文《Attention Is All You Need》中提出的Transformer架构,完全摒弃了RNN的循环机制,采用一种self-attention的方式进行全局处理。其接收一整段序列,并使用三个可训练的权重矩阵——Query、Key和Value来一次性学习输入序列中各个部分之间的依赖关系。Transformer网络由多个层组成,每个层都由多头注意力机制和前馈网络构成。由于在全局进行注意力机制的计算,忽略了序列中最重要的位置信息。Transformer为输入添加了位置编码(Positional Encoding),使用正弦函数完成,为每个部分的位置生成位置向量,不需要学习,用于帮助网络学习其位置信息。其示意如下图所示:
有关Transformer的更深入讨论,可参考笔者之前的博客:
Transformer(论文 + PyTorch源码解读)
二. vanilla Transformer
为何要提这个模型?因为Transformer-XL是基于这个模型进行的改进。
Al-Rfou等人基于Transformer提出了一种训练语言模型的方法( https://arxiv.org/abs/1808.04444 ),来根据之前的字符预测片段中的下一个字符。例如,它使用 x 1 , x 2 , . . . , x n − 1 x 1 , x 2 , . . . , x n − 1 x 1 , x 2 , . . . , x n − 1 x1,x2,...,xn−1x_1, x_2, ..., x_{n-1}x1,x2,...,xn−1 x1,x2,...,xn−1x1,x2,...,xn−1x1,x2,...,xn−1预测字符 x n x n x n xnx_nxn xnxnxn,而在 x n x n x n xnx_nxn xnxnxn之后的序列则被mask掉。论文中使用64层模型,并仅限于处理 512个字符这种相对较短的输入,因此它将输入分成段,并分别从每个段中进行学习,如下图所示。 在测试阶段如需处理较长的输入,该模型会在每一步中将输入向右移动一个字符,以此实现对单个字符的预测。
该模型在常用的数据集如enwik8和text8上的表现比RNN模型要好,但它仍有以下两个缺点:
a. 上下文长度受限:字符之间的最大依赖距离受输入长度的限制,模型看不到出现在几个句子之前的单词。
b. 上下文碎片:对于长度超过512个字符的文本,都是从头开始单独训练的。段与段之间没有上下文依赖性,会让训练效率低下,也会影响模型的性能。
c. 推理速度慢:在测试阶段,每次预测下一个单词,都需要重新构建一遍上下文,并从头开始计算,这样的计算速度非常慢。
三. Transformer-XL
Transformer-XL架构在vanilla Transformer的基础上引入了两点创新:循环机制(Recurrence Mechanism)和相对位置编码(Relative Positional Encoding),以克服vanilla Transformer的缺点。与vanilla Transformer相比,Transformer-XL的另一个优势是它可以被用于单词级和字符级的语言建模。
1. 引入循环机制
与vanilla Transformer的基本思路一样,Transformer-XL仍然是使用分段的方式进行建模,但其与vanilla Transformer的本质不同是在于引入了段与段之间的循环机制,使得当前段在建模的时候能够利用之前段的信息来实现长期依赖性。如下图所示:
在训练阶段,处理后面的段时,每个隐藏层都会接收两个输入:
- 该段的前面隐藏层的输出,与vanilla Transformer相同(上图的灰色线)。
- 前面段的隐藏层的输出(上图的绿色线),可以使模型创建长期依赖关系。
这两个输入会被拼接,然后用于计算当前段的Key和Value矩阵。对于某个段的某一层的具体计算公式如下:
其中, τ τ τ τ\tauτ τττ表示第几段, n n n nnn nnn表示第几层, h h h hhh hhh表示隐层的输出。 S G ( ⋅ ) S G ( ⋅ ) S G ( ⋅ ) SG(⋅)SG(·)SG(⋅) SG(⋅)SG(⋅)SG(⋅)表示停止计算梯度, [ h u ∘ h v ] [ h u ∘ h v ] [ h u ∘ h v ] [hu∘hv][h_u \circ h_v][hu∘hv] [hu∘hv][hu∘hv][hu∘hv]表示在长度维度上的两个隐层的拼接, W . W . W . W.W_.W. W.W.W.是模型参数。乍一看与Transformer中的计算公式很像,唯一关键的不同就在于Key和Value矩阵的计算上,即 k τ + 1 n k τ + 1 n k τ + 1 n kτ+1nk_{\tau+1}^nkτ+1n kτ+1nkτ+1nkτ+1n和 v τ + 1 n v τ + 1 n v τ + 1 n vτ+1nv_{\tau + 1}^nvτ+1n vτ+1nvτ+1nvτ+1n,它们基于的是扩展后的上下文隐层状态 h τ + 1 n − 1 h ~ τ + 1 n − 1 h τ + 1 n − 1 h~τ+1n−1\tilde{h}_{\tau+1}^{n-1}h~τ+1n−1 h τ+1n−1h~τ+1n−1h τ+1n−1进行计算, h τ n − 1 h τ n − 1 h τ n − 1 hτn−1{h}_{\tau}^{n-1}hτn−1 hτn−1hτn−1hτn−1是之前段的缓存。
原则上只要GPU内存允许,该方法可以利用前面更多段的信息,测试阶段也可以获得更长的依赖。
在测试阶段,与vanilla Transformer相比,其速度也会更快。在vanilla Transformer中,一次只能前进一个step,并且需要重新构建段,并全部从头开始计算;而在Transformer-XL中,每次可以前进一整个段,并利用之前段的数据来预测当前段的输出。
2. 相对位置编码
在Transformer中,一个重要的地方在于其考虑了序列的位置信息。在分段的情况下,如果仅仅对于每个段仍直接使用Transformer中的位置编码,即每个不同段在同一个位置上的表示使用相同的位置编码,就会出现问题。比如,第 i − 2 i − 2 i − 2 i−2i-2i−2 i−2i−2i−2段和第 i − 1 i − 1 i − 1 i−1i-1i−1 i−1i−1i−1段的第一个位置将具有相同的位置编码,但它们对于第 i i i iii iii段的建模重要性显然并不相同(例如第 i − 2 i − 2 i − 2 i−2i-2i−2 i−2i−2i−2段中的第一个位置重要性可能要低一些)。因此,需要对这种位置进行区分。
论文对于这个问题,提出了一种新的位置编码的方式,即会根据词之间的相对距离而非像Transformer中的绝对位置进行编码。在Transformer中,第一层的计算查询 q i T q i T q i T qiTq_i^TqiT qiTqiTqiT和键 k j k j k j kjk_jkj kjkjkj之间的attention分数的方式为:
其中, E x i E x i E x i ExiE_{x_i}Exi ExiExiExi是词 i i i iii iii的embedding, E x j E x j E x j ExjE_{x_j}Exj ExjExjExj是词 j j j jjj jjj的embedding, U i U i U i UiU_iUi UiUiUi和 U j U j U j UjU_jUj UjUjUj是位置向量,这个式子实际上是 ( W q ( E x i + U i ) ) T ⋅ ( W k ( E x j + U j ) ) ( W q ( E x i + U i ) ) T ⋅ ( W k ( E x j + U j ) ) ( W q ( E x i + U i ) ) T ⋅ ( W k ( E x j + U j ) ) (Wq(Exi+Ui))T⋅(Wk(Exj+Uj))(W_q(E_{x_i}+U_i))^T·(W_k(E_{x_j}+U_j))(Wq(Exi+Ui))T⋅(Wk(Exj+Uj)) (Wq(Exi+Ui))T⋅(Wk(Exj+Uj))(Wq(Exi+Ui))T⋅(Wk(Exj+Uj))(Wq(Exi+Ui))T⋅(Wk(Exj+Uj))的展开,就是Transformer中的标准格式。
在Transformer-XL中,对上述的attention计算方式进行了变换,转为相对位置的计算,而且不仅仅在第一层这么计算,在每一层都是这样计算。
对比来看,主要有三点变化:
- 在(b)和(d)这两项中,将所有绝对位置向量 U j U j U j UjU_jUj UjUjUj都转为相对位置向量 R i − j R i − j R i − j Ri−jR_{i-j}Ri−j Ri−jRi−jRi−j,与Transformer一样,这是一个固定的编码向量,不需要学习。
- 在(c)这一项中,将查询的 U i T W q T U i T W q T U i T W q T UiTWqTU_i^TW_q^TUiTWqT UiTWqTUiTWqTUiTWqT向量转为一个需要学习的参数向量 u u u uuu uuu,因为在考虑相对位置的时候,不需要查询的绝对位置 i i i iii iii,因此对于任意的 i i i iii iii,都可以采用同样的向量。同理,在(d)这一项中,也将查询的 U i T W q T U i T W q T U i T W q T UiTWqTU_i^TW_q^TUiTWqT UiTWqTUiTWqTUiTWqT向量转为另一个需要学习的参数向量 v v v vvv vvv。
- 将键的权重变换矩阵 W k W k W k WkW_kWk WkWkWk转为 W k , E W k , E W k , E Wk,EW_{k, E}Wk,E Wk,EWk,EWk,E和 W k , R W k , R W k , R Wk,RW_{k, R}Wk,R Wk,RWk,RWk,R,分别作为content-based key vectors和location-based key vectors。
从另一个角度来解读这个公式的话,可以将attention的计算分为如下四个部分:
a. 基于内容的“寻址”,即没有添加原始位置编码的原始分数。
b. 基于内容的位置偏置,即相对于当前内容的位置偏差。
c. 全局的内容偏置,用于衡量key的重要性。
d. 全局的位置偏置,根据query和key之间的距离调整重要性。
3. 整体计算公式
结合上面两个创新点,将Transformer-XL模型的整体计算公式整理如下,这里考虑一个N层的只有一个注意力头的模型:
其中, τ τ τ τ\tauτ τττ代表第几段, n n n nnn nnn代表第几层, h τ 0 : = E s τ h τ 0 : = E s τ h τ 0 : = E s τ hτ0:=Esτh_\tau^0 := E_{s_\tau}hτ0:=Esτ hτ0:=Esτhτ0:=Esτhτ0:=Esτ定义为第 τ τ τ τ\tauτ τττ段的词向量序列。值得一提的是,计算 A A A AAA AAA矩阵的时候,需要对所有的 i − j i − j i − j i−ji-ji−j i−ji−ji−j计算 W k , R n R i − j W k , R n R i − j W k , R n R i − j Wk,RnRi−jW_{k,R}^nR_{i-j}Wk,RnRi−j Wk,RnRi−jWk,RnRi−jWk,RnRi−j,如果直接按照公式计算的话,计算时间是 O ( l e n g t h ) 2 O ( l e n g t h ) 2 O ( l e n g t h ) 2 O(length)2O(length)^2O(length)2 O(length)2O(length)2O(length)2,而实际上 i − j i − j i − j i−ji-ji−j i−ji−ji−j的范围只从0 ~ length,因此可以先计算好这length个向量,然后在实际计算 A A A AAA AAA矩阵时直接取用即可。
具体的,设 M M M MMM MMM和 L L L LLL LLL分别为memory和当前段序列的长度,则 i − j i − j i − j i−ji-ji−j i−ji−ji−j的范围也就为0 ~ M + L − 1 M + L − 1 M + L − 1 M+L−1M + L - 1M+L−1 M+L−1M+L−1M+L−1。下面的 Q Q Q QQQ QQQ矩阵中的每一行都代表着 W k , R R i − j W k , R R i − j W k , R R i − j Wk,RRi−jW_{k,R}R_{i-j}Wk,RRi−j Wk,RRi−jWk,RRi−jWk,RRi−j中一个 i − j i − j i − j i−ji-ji−j i−ji−ji−j的可能性,即 Q k = W k , R R M + L − 1 − k Q k = W k , R R M + L − 1 − k Q k = W k , R R M + L − 1 − k Qk=Wk,RRM+L−1−kQ_k = W_{k, R} R_{M+L-1-k}Qk=Wk,RRM+L−1−k Qk=Wk,RRM+L−1−kQk=Wk,RRM+L−1−kQk=Wk,RRM+L−1−k。
则对于上面公式中的(b)项,即 q i T W k , R R i − j q i T W k , R R i − j q i T W k , R R i − j qiTWk,RRi−jq_i^TW_{k,R}R_{i-j}qiTWk,RRi−j qiTWk,RRi−jqiTWk,RRi−jqiTWk,RRi−j,其构成的所有可能向量的矩阵为 B B B BBB BBB矩阵,其形状为 L ∗ ( M + L ) L ∗ ( M + L ) L ∗ ( M + L ) L∗(M+L)L * (M + L)L∗(M+L) L∗(M+L)L∗(M+L)L∗(M+L),这是我们最终需要的(b)项的attention结果。
我们进一步定义 B B ~ B B~\tilde{B}B~ B B~B 矩阵为如下:
可见,需要的 B B B BBB BBB矩阵的每一行只是 B B ~ B B~\tilde{B}B~ B B~B 的向左shift而已。因此,可以直接利用矩阵乘法计算 B B ~ B B~\tilde{B}B~ B B~B 即可。设 R i − j R i − j R i − j Ri−jR_{i-j}Ri−j Ri−jRi−jRi−j的维度为 d R d R d R dRd_RdR dRdRdR, q i q i q i qiq_iqi qiqiqi的维度为 d q d q d q dqd_qdq dqdqdq, W k , R W k , R W k , R Wk,RW_{k,R}Wk,R Wk,RWk,RWk,R矩阵的维度为 d q ∗ d R d q ∗ d R d q ∗ d R dq∗dRd_q * d_Rdq∗dR dq∗dRdq∗dRdq∗dR,则直接计算矩阵B的时间复杂度为 2 ∗ d q ∗ d R ∗ L ∗ ( M + L ) 2 ∗ d q ∗ d R ∗ L ∗ ( M + L ) 2 ∗ d q ∗ d R ∗ L ∗ ( M + L ) 2∗dq∗dR∗L∗(M+L)2* d_q * d_R * L * (M+L)2∗dq∗dR∗L∗(M+L) 2∗dq∗dR∗L∗(M+L)2∗dq∗dR∗L∗(M+L)2∗dq∗dR∗L∗(M+L),而计算 B B ~ B B~\tilde{B}B~ B B~B 的时间复杂度为 L ∗ d q ∗ ( M + L ) + d q ∗ d R ∗ ( M + L ) L ∗ d q ∗ ( M + L ) + d q ∗ d R ∗ ( M + L ) L ∗ d q ∗ ( M + L ) + d q ∗ d R ∗ ( M + L ) L∗dq∗(M+L)+dq∗dR∗(M+L)L * d_q * (M + L) + d_q * d_R * (M + L)L∗dq∗(M+L)+dq∗dR∗(M+L) L∗dq∗(M+L)+dq∗dR∗(M+L)L∗dq∗(M+L)+dq∗dR∗(M+L)L∗dq∗(M+L)+dq∗dR∗(M+L),计算量明显不是一个量级(后者要快很多)。
同理,对于(d)项来说,可以对所有的 i − j i − j i − j i−ji-ji−j i−ji−ji−j定义需要的矩阵 D D D DDD DDD为 L ∗ ( M + L ) L ∗ ( M + L ) L ∗ ( M + L ) L∗(M+L)L * (M+L)L∗(M+L) L∗(M+L)L∗(M+L)L∗(M+L):
可以用如下的 d d ~ d d~\tilde{d}d~ d d~d 来进行shift得到:
其中 Q Q Q QQQ QQQ矩阵已经计算过了,也可以在这一步减少计算量。
四. PyTorch实现
笔者在这里主要研究的是核心模型部分,将针对关键的实现细节进行剖析,想要看完整代码的读者请戳这里。
- 首先来看RelativePositionalEmbedding部分。
class PositionalEmbedding(nn.Module): def __init__(self, demb): super(PositionalEmbedding, self).__init__() self.demb = demb inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
def forward(self, pos_seq): sinusoid_inp = torch.ger(pos_seq, self.inv_freq) pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) return pos_emb[:,None,:]