TransformerXL解读

背景

对语言模型建模,RNN和Transformer都是能提取长距离的依赖关系的特征提取器。RNN方面,由于本身的recurrent机制,可以接受任意长度的序列作为输入,但是由于梯度消失和爆炸(gradient vanishing and explosion)和无法并行计算等问题,实际效果不佳;Transformer作为新贵,虽然不存在上述问题,但是由于实际不可能输入任意长度的词encoding到fixed length,只能先按某个固定最大长度分chunks再对每个chunks计算,这就带来了两个问题,即模型无法建立chunks之间的依赖关系(对长文本处理不好)和因为边界问题对开头的几个单词预测的不好(context fragmentation).

对此TransformerXL解决了Vanilla Transformer存在的这些问题(XL的意思的extra long,针对超长文本)。论文参考《Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context》

Transformer回顾

TransformerXL解读_第1张图片
构建语言模型通常是用t时刻之前的序列预测t时刻词的概率,通过模型将t时刻之前的输入encode到固定长度的representation, 然后经过线性变换后的softmax得到t时刻词的概率。
而基于传统的transformer即vinilla transformer存在的问题主要在于如何将任意长度的sequence编码成固定长度的representation。实际由于算力问题,有一种解决办法是将长序列的文本划分为多个chunks, 每个chunk单独训练模型忽略不同chunk之间的依赖关系。论文见《Character-Level Language Modeling with Deeper Self-Attention》
预测的时候,每次和训练的时候chunk一样输入固定长度的序列,但是只预测下一个token, 之后shift右移一位当作新的chunk,重新计算下一个token,这种方式较为费力。

TransformerXL模型架构

Segment-Level Recurrence with State Reuse

为了解决vinilla transformer的问题,这里提出了循环机制。

During training, the hidden state sequence computed for the previous segment is fixed and cached to be reused as an extended context when the model processes the next new segment。

训练的时候TransformerXL对每个segment的hidden state保留cache作为下一个segment的输入,这样就把不同segment的长距离依赖关系进行捕捉。
TransformerXL解读_第2张图片
其中h表示hiddent state,n表示第n层transformer,t表示第t个segment, SG表示stop gradient,记不算上一个segment的梯度;计算公式可以看出,和vinilla transformerr相比,区别在于计算k和v的时候,是利用上一个segment的hidden state和当前segment的hidden state进行concat之后的结果,这样就能捕捉更长的依赖关系了。由于当前层的hidden state是由下一层的包含当前时刻和前L-1个state计算出来的,依次类推,最长依赖关系正比于O(N × L),N为segment的总个数, L为每个segment的固定长度,通常L>N 。
TransformerXL解读_第3张图片
另外出了上述的好处外,在预测的时候,由于缓存了之前的hidden state,再计算预测之后的token的时候不需要重新计算,比vinilla transformer快上千倍。

Relative Positional Encodings

上述循环机制还有个问题没有解决,就是transformer的position encoding只在segment中由绝对位置编码,却没有跨越segment的相对位置编码,这样模型无法区分不同segment的相同位置的区别。
传统的transformer的attention计算公式为 ( E x i + U i ) W q W k ( E x j + U j ) (E_{x_i}+U_i)W_qW_k(E_{x_j}+U_j) (Exi+Ui)WqWk(Exj+Uj), 其中 W q W_q Wq W k W_k Wk为key和value对应的的矩阵,E为token的embedding,U为position encoding,展开如下
TransformerXL解读_第4张图片
TransformerXL对此做了如下改变TransformerXL解读_第5张图片

  1. 将计算key向量的绝对位置编码 U j U_j Uj换成了 R i − j R_{i-j} Rij, 这是一个sinusoid encoding matrix,不是学习得到的
  2. 将query向量的绝对位置向量替换成了可训练的向量u和v,这是因为这里采用相对位置编码,i位置的绝对编码没有意义
  3. W k W_k Wk替换成了分别基于位置(location-based)和内容(content-based)的矩阵,计算得到不同的key向量

总的来说,Relative Positional Encodings就是在计算注意力分数时,用相对位置 R i − j R_{i-j} Rij和学习了的相对位置 v v v u u u向量来代替绝对位置编码 U i U_i Ui U j U_j Uj

改造后的TransformerXL公式为:
TransformerXL解读_第6张图片

总结

TransformerXL通过循环机制利用上个segment的信息,并且将绝对位置编码改成相对位置编码,解决了普通Transformer无法建立超过固定长度文本的长依赖问题和context fragmentation,在预测的效率也大幅提升。

你可能感兴趣的:(NLP,自然语言处理,神经网络,机器学习,深度学习)