Transformer最大的问题在于没有办法建模超过最大长度的序列,例如base bert其支持的序列最大长度是512,超过了该长度的序列需要进行截取,再把截取后的片段分别用bert进行编码,该方法虽然可行,但是存在上下文碎片化的问题,也就是说每个片段是单独建模的,互相之间没有上下文信息,并且,不同的片段位置编码都是从0开始,明显是有问题的。
可见Transformer对于较长的序列建模能力有限,如何解决该弊端就该Transformer-XL大显身手了。
Transformer-XL主要提出了两个优化点
接下来我们分别看下两个优化点是如何做的
在讲解第一个优化点之前,我们简单回顾下vanilla transformer,在训练阶段如果要对多个片段编码,其训练过程如下图,可以看到,两个片段没有相互依赖,上下文信息会丢失,不同的片段位置编码一样,因此也不准确。
再看下inference阶段,对于第一个segment,预测和vanilla版本一样的,跨段预测时(大于第一个片段的序列),由于依赖的上下文长度是固定的,可以理解为使用了一个滑动窗口,每次窗口的值都不一样,所以每次只能预测一个字/词,并且每次都要完整的计算,例如下图中,每个segment长度是4,超过4的部分只能逐字/词计算。
为了解决固定长度的限制,Transformer-XL提出了一种递归机制,如下图,第一个segment计算完成后,把计算的结果保存下来,在计算第二个片段的时候,把第一个片段的hidden state和第二个片段的hidden state拼接在一起,再进行后续的计算。
我们看下具体的计算公式,其中h表示的是hidden state, τ \tau τ表示第 τ \tau τ个segment,SG函数表示的是不更新梯度,[]表示的是向量的拼接,第一个公式的意思即:第 τ + 1 \tau +1 τ+1个segment第n-1层的hidden state 等于第 τ \tau τ个segment第n-1层的hidden state拼接上第 τ + 1 \tau +1 τ+1个segment第n-1层的hidden state,后续两个公式和vanilla版本类似,但要注意,q是未拼接的hidden state,k、v是拼接过后的,因为q表示的是当前的segment,所以不需要拼接。
可以看到,对于第一个segment来说,hidden state是没有额外需要拼接的值的,从第二个segment开始才需要拼接,在论文中,每次都是和上一个segment进行拼接,理论上来说每次可以拼接多个segment,第n个segment可以和前n-1个segment进行拼接,不过这个就取决于你自己的显存了,并且一个segment通常来说不会像上图中的这么短(一个segment可能长度就512了),文本自身的上下文依赖一般也不会超过一个segment的长度。
再看下inference阶段,大于第一个segment的序列,均可以进行批计算,每个批的长度是segment的长度,并且,由于每次都会保存前一个segment的hidden state,所以不需要像vanilla版本重新计算。论文中对比了一下,Transformer-XL在enwiki8数据集上的inference速度是Vanilla Transformer的1800+倍
接下来我们来看第二个优化点,相对位置编码。Vanilla Transformer使用的是绝对位置编码,其计算方式如下,pos表示的是token的下标, d m o d e l d_{model} dmodel表示的是hidden size,i表示的是具体的某个维度。
可见,不同的片段的同一个位置其位置编码都是一样的,模型没办法正确区分不同片段的位置信息,我们再看下Transformer-XL的位置编码是怎么做的。
Vanilla的位置编码是和embedding相加后输入到下一层的,Transformer-XL的位置编码没有在输入上做处理,而是对attention score进行了修改,先回顾下vanilla版本attention score的计算
A a b s = Q W q K W k A^{abs}=QW_q KW_k Aabs=QWqKWk
把Q和K展开,E表示embedding,U表示位置编码
A a b s = ( E q + U q ) W q ( E k + U k ) W k = ( E q W q + U q W q ) ( E k W k + U k W k ) = E q W q E k W k + E q W q U k W k + U q W q E k W k + U q W q U k W k \begin{aligned} A^{abs}&=(E_q+U_q)W_q (E_k+U_k)W_k \\ &=(E_qW_q+U_qW_q)(E_kW_k+U_kW_k) \\ &=E_qW_qE_kW_k + E_qW_qU_kW_k+U_qW_qE_kW_k+U_qW_qU_kW_k \end{aligned} Aabs=(Eq+Uq)Wq(Ek+Uk)Wk=(EqWq+UqWq)(EkWk+UkWk)=EqWqEkWk+EqWqUkWk+UqWqEkWk+UqWqUkWk
即论文中下图的公式
考虑一下,当query与key进行计算时,实际上并不需要知道key的绝对位置编码,模型实际上需要的是一个“时间线索”即字词的一个先后顺序,因此,知道query与key的相对位置即可。根据以上的思路,Transformer-XL做了三个方面的改进,分别如下
在新的参数下,每一项都有了一个具体的含义,a表示的是query与key的内容相关性,b表示的是query的内容和key的位置的相关性,c表示的是query的位置与key的内容的相关性,d表示的是quey与key的位置的相关性
总结一下,对于一个N层1个head的Transformer-XL,其完整步骤如下
除此之外,论文中对b与d的计算做了一定的优化, R i − j R_{i-j} Ri−j需要分别计算i与j的值,时间复杂度是 O ( n 2 ) O(n^2) O(n2),优化后能达到 O ( n ) O(n) O(n)
首先定义一个Q矩阵,表示相对位置编码,注意,R是反着来的从 M + L − 1 M+L−1 M+L−1到0
把b的结果展开,实际上是一个 L × ( M + L ) L × (M + L) L×(M+L)的matrix,其中L表示segment的长度,M表示memory的长,结果为
如果我们定义一个 B ~ = q Q T \widetilde B=qQ^T B =qQT则有
对比下 B B B与 B ~ \widetilde B B ,第i行的 B B B实际上就是第i行的 B ~ \widetilde B B 进行了左移,因此,计算 B B B只需要先计算出 B ~ \widetilde B B 然后按行左移。同理,d也可按照相同的方法进行计算
Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context