【论文阅读笔记】Transformer-XL

Paper: Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context
重点关注论文中的相对位置编码及提高融合了相对位置信息的attention score的计算效率的部分。

Abstract

Transformer具有学习长依赖的能力,但受限于语言模型固定长度上下文的限定。本文提出的Transformer-XL神经网络架构可以在不打破时序关系的前提下突破固定长度上下文的限制,学习文本间的依赖关系。模型具体包括一个片段级别的循环机制和一个全新的位置编码方式。该架构不仅可以学习文本中的长依赖关系,还解决了上下文碎片问题(context fragmentation problem)。最终,Transformer-XL可以习得相较RNN长80%、相较原始Transformer长450%的依赖关系,并且在评估时的速度最多比原始Transformer快1800多倍。作者还提供了Transformer-XL的Tensorflow和PyTorch的实现版本。

Introduction

本文关注的是基于神经网络的架构使得模型具备为序列数据的长依赖关系进行建模的能力的问题。RNN由于梯度消失和梯度爆炸的问题难以优化,即使是引入了门机制的LSTM和梯度裁剪技术,以上问题仍旧未能得到完全解决,同时一般而言,LSTM语言模型平均使用长度为200的上下文单词,因此尚有一定的提升空间。

另一方面,可以直接捕捉两个距离较远的单词之间关系的attention机制或许有助于实现长依赖关系的学习。相关的研究有很多,但受限于固定长度的上下文,模型无法捕捉那些长度超过预定义的上文长度的文本依赖关系。还有方法在不考虑句子或其它语义边界的情况下选择连续字符构成长度固定的片段(fixed-length segment)进行建模,但这样的模型在前几步的预测中缺乏必要的上下文信息,继而带来优化及性能方面的问题,本文将该方面的问题称为上下文碎片问题(context fragmentation)。

为了解决上述问题,本文提出了名为Transformer-XL的架构,其中XL意为extra long,该机制将循环的概念引入了深度自注意力网络中。具体而言,对于每一个新片段(segment),在计算其隐藏状态时会复用之前片段的隐藏状态,而非从头开始计算。复用的隐藏状态作为当前片段的记忆单元,从而建立起了片段之间的循环连接。这样的做法由于信息可以在片段的循环连接之间得以传播使得为特别长的依赖关系建模成为可能,同时也解决了上下文碎片的问题。此外。为了在复用状态时不会引发时许混淆的问题,本文展现了使用相对位置编码的必要性。一个简单但更高效的相对位置编码公式也有利于那些长度超过训练时的注意力长度的内容学习上的泛化。

Transformer-XL是首个同时在字符级别(character-level)和词级别(word-level)语言模型上超越RNN模型的自注意力模型。

Related Work

语言模型领域近年来的发展有很多,如设计更好的编码上下文的新架构、改进的正则化或优化算法、softmax计算的加速以及对输出分布的优化等等。

为了捕捉语言模型中的长范围的上下文,部分工作直接将更长的上下文表示作为附加输入送入神经网络中。现有的工作包括人为定义上下文表示以及从数据中学习篇章级别的主题等等。

更广泛而言,在一般的序列建模问题中,如何捕捉长依赖关系一直是一个长期存在的研究问题。由于LSTM的普适性,大量工作关注解决其梯度消失的问题,包括更好的参数初始化、附加的损失计算、增强的记忆单元结构以及一些修改RNN结构以便于优化的方法等。与这些做法不同的是,本文的工作基于Transformer架构,同时证明了学习长依赖关系的能力对现实任务中的语言建模的优势。

Model

给定token的语料库:,语言模型的任务是估计联合概率。基于因式分解,该问题简化为估计各条件因子。本工作采用标准的神经网络方法为各条件概率建模。具体而言,以一个可训练的神经网络将上下文编码为一个固定大小的隐藏状态,继而乘上词嵌入以获得其逻辑表示,该表示将送入softmax方程产生下一个token的概率分布。

Vanilla Transformer Language Models

将Transformer或自注意力机制用于语言模型的一个可行方案是,将整个语料划分为若干较短的可管理的片段,同时忽略之前片段的上下文信息,仅在各片段内部训练模型。本文将该模型称为Vanilla Model,其过程如Figure 1所示。

在该模型下的训练过程中,信息无法在片段间流动。使用固定长度的上下文存在两点关键限制:①可能获取的依赖长度上限由片段长度决定。而在字符级别的语言模型中,片段长度需要有好几百,即使自注意力机制能在一定程度上缓解RNN梯度消失的问题,但该模型认为充分利用自注意力机制的这一优化优势。②尽管可以通过padding延续文本的句子或其他语义边界特性,但事实上为了提高效率,简单将长文本划分成固定长度的片段已然成为标准做法,继而引发了上文提及的上下文碎片问题。

在评估阶段的每一步中,the vanilla model依旧采用训练阶段相同的片段长度,但仅对最后一个位置进行预测。而在下一步中,片段将向右平移一个位置,再重新从头开始处理整个片段进行当前片段最后位置的预测。如图所示,该过程确保每一次预测用到了训练阶段能看到的最长的上下文,同时缓解了训练阶段的上下文碎片问题。但相应的评估阶段的计算成本也有所提高。这一点在本文提出的框架中得以解决。

Segment-Level Recurrence with State Reuse

为了解决使用固定长度上下文带来的限制,本文提出在Transformer架构中引入循环机制,其过程如Figure 2所示。

在训练阶段,前一个片段计算得到的隐藏状态序列将被固定(fixed)并缓存起来(cached),在模型处理接下来的一个新片段时,刚刚缓存的隐藏层序列将作为一个扩展上下文进行复用。尽管梯度仍保留于每个片段内部,但这个附加的输入使得网络可以处理历史信息,继而使得模型可以对长依赖建模,同时避免了上下文碎片问题。该过程以公式化形式将表述如下,将两个长度为的连续片段分别表示如下:和;将第个片段的第层的隐藏状态序列表示为,其中是隐藏层状态维度。然后,将第个片段的第层的隐藏状态序列的计算过程如下:

\begin{array}{l} \widetilde{\mathbf{h}}_{\tau+1}^{n-1}=\left[\mathrm{SG}\left(\mathbf{h}_{\tau}^{n-1}\right) \circ \mathbf{h}_{\tau+1}^{n-1}\right] \\ \mathbf{q}_{\tau+1}^{n}, \mathbf{k}_{\tau+1}^{n}, \mathbf{v}_{\tau+1}^{n}=\mathbf{h}_{\tau+1}^{n-1} \mathbf{W}_{q}^{\top}, \widetilde{\mathbf{h}}_{\tau+1}^{n-1} \mathbf{W}_{k}^{\top}, \widetilde{\mathbf{h}}_{\tau+1}^{n-1} \mathbf{W}_{v}^{\top} \\ \mathbf{h}_{\tau+1}^{n}=\text { Transformer-Layer }\left(\mathbf{q}_{\tau+1}^{n}, \mathbf{k}_{\tau+1}^{n}, \mathbf{v}_{\tau+1}^{n}\right) \end{array}

其中函数表示停止梯度计算,表示两个隐藏层序列沿length维度的拼接,表示模型参数。与标准Transformer相比,关键的不同点在于,key 和value 基于扩展后上下文得来,因此可从之前的片段中获取信息。Figure 2(a)中由绿色路径标注了本文的特殊设计。

通过在每两个连续片段之间应用循环机制,建立起了隐藏层片段级别的循环。因此有效的上下文信息将不仅仅在两个片段内被利用。然而需要注意的是,和之间的循环依赖每一片段将向下移动一层,这与传统的基于RNN的语言模型中的同层循环是有所不同的。最终,最长依赖长度随层数和片段长度呈线性增长,即,如Figure 2(b)的阴影部分所示。这与一训练基于RNN的语言模型采用的方法truncated BPTT类似。但与其不同的是,本文提出的方法缓存的是隐藏状态序列而非上一序列,同时还应该结合后文将介绍的相对位置编码技术一起使用。

该架构除了能利用更长的上下文以及解决上下文碎片问题外,循环机制还使得评估时的效率显著提高。具体而言,在评估阶段,之前片段的表示可以同the vanilla model一样进行重用。

最后需要注意的是,循环机制不必仅限于邻接的前一个片段。理论上,在GPU内存允许的情况下,可以缓存尽可能多的之前的片段,并在处理当前片段时复用所有的这些片段作为额外的上下文。因此,可以缓存一个预定义的长度——个旧的隐藏状态,并将它们表示为记忆单元。实验中,本文将设为等同于片段长度的大小,并在评估中,将其值加倍增长。

Relative Positional Encoding

上述方案存在的问题是重用隐藏层状态的顺序问题,即在重时是如何保证位置信息的连贯问题(the positional information coherent)。在标准Transformer中,序列顺序信息是通过一个位置编码集合提供的,其中第行表示某一片段中第个绝对位置,表示建模的最大长度;随后输入将有文本的词嵌入表示和位置编码相加得来。倘若将这样的位置编码方式直接运用到本文的循环机制中,隐藏状态序列的计算如下:

其中表示序列的词嵌入,表示一个转换方程。需要注意的是,和用到了同样的位置编码。因此,对于任意的,模型没有用于分辨和位置区别的信息,继而造成严重的性能损失。

为了避免上述问题,最基础的想法是在隐藏状态中仅编码相对位置信息。从概念上来说,位置编码给予了模型如何汇聚信息的时序线索。出于同样的目的,可以在每一层中将类似的信息映射到attention分值上。更重要的是,以相对位置定义时序偏差是更直观且有利于泛化的。例如,当一个query向量在key向量上计算注意力时,无需了解每一个key向量的绝对位置,了解每一个key向量和自身的相对位置即可反映片段内的时序关系。在实践上,可以创建一个相对位置编码集合,其中第行表示i和其它位置的相对距离。通过将相对位置动态地映射到注意力分值中,query向量可以轻松地根据不同的距离区分和的表示,继而使得状态重用机制可行。与此同时,由于绝对位置信息可以递归地从相对距离中获取,时序信息并未丢失。

过去,相对位置编码的思想已被用于机器翻译和音乐生成任务中。这里,本文提出一种不同的相对位置编码新形式的推导,不仅与其绝对位置有一对一的对应关系,而且具有更好的泛化能力。首先,在标准Transformer中,同一片段内的query向量和key向量之间的注意力得分计算可做如下分解:
\begin{aligned} \mathbf{A}_{i, j}^{\text {abs }} &= q_i^\top k_j \\ &=[\mathbf{W}_q (\mathbf{E}_{x_i}+\mathbf{U}_i)]^\top [\mathbf{W}_k(\mathbf{E}_{x_j}+\mathbf{U}_j)] \\ &= \underbrace{(\mathbf{E}_{x_{i}}^{\top} \mathbf{W}_{q}^{\top}) (\mathbf{W}_{k} \mathbf{E}_{x_{j}})}_{(a)}+\underbrace{(\mathbf{E}_{x_{i}}^{\top} \mathbf{W}_{q}^{\top}) (\mathbf{W}_{k} \mathbf{U}_{j})}_{(b)} \\ &+\underbrace{(\mathbf{U}_{i}^{\top} \mathbf{W}_{q}^{\top}) (\mathbf{W}_{k} \mathbf{E}_{x_{j}})}_{(c)}+\underbrace{(\mathbf{U}_{i}^{\top} \mathbf{W}_{q}^{\top}) (\mathbf{W}_{k} \mathbf{U}_{j})}_{(d)} \end{aligned}
根据仅依赖相对位置信息的思想,将上式中的四项重新参数化如下:
\begin{aligned} \mathbf{A}_{i, j}^{\text {rel }} &=\underbrace{(\mathbf{E}_{x_{i}}^{\top} \mathbf{W}_{q}^{\top}) (\mathbf{W}_{k,E} \mathbf{E}_{x_{j}})}_{(a)}+\underbrace{(\mathbf{E}_{x_{i}}^{\top} \mathbf{W}_{q}^{\top}) (\mathbf{W}_{k,R} {\color{blue}{\mathbf{R}_{i-j}}})}_{(b)} \\ &+\underbrace{{\color{red}{u^{\top}}} (\mathbf{W}_{k,E} \mathbf{E}_{x_{j}})}_{(c)}+\underbrace{{\color{red}{v^{\top}}} (\mathbf{W}_{k,R} {\color{blue}{\mathbf{R}_{i-j}}})}_{(b)} \end{aligned}

  • 将所有在公式项和中出现的计算key向量用到的绝对位置编码替换为对应的绝对位置编码。这从本质上反映了只考虑相对位置的前提。需要注意的是,是没有可训练参数的正弦编码矩阵。
  • 引入了可训练参数来替代公式项中的query向量。在采用相对位置编码的情况下,无论是哪个查询位置,此处的query向量应是一致的,其位置信息由相对位置编码反映,因此此处采用一个可训练的参数表示。出于同样的原因,将公式项中的替换为可训练参数。
  • 将两个权重矩阵和区别开来,以分别表示基于内容的key向量和基于位置的key向量。

在这样全新的参数化表示下,每一项都具备一个直观的含义:表示内容上的关联(content-based addressing);捕捉了依赖内容的位置偏差;控制着全局内容偏差;编码了全局位置偏差。

综上,带有单个注意力头的层的Transformer-XL的计算过程如下:对于:
\begin{array}{l} \widetilde{\mathbf{h}}_{\tau}^{n-1}=\left[\mathrm{SG}\left(\mathbf{m}_{\tau}^{n-1}\right) \circ \mathbf{h}_{\tau}^{n-1}\right] \\ \mathbf{q}_{\tau}^{n}, \mathbf{k}_{\tau}^{n}, \mathbf{v}_{\tau}^{n}=\mathbf{h}_{\tau}^{n-1} \mathbf{W}_{q}^{n \top}, \widetilde{\mathbf{h}}_{\tau}^{n-1} \mathbf{W}_{k,E}^{n \top}, \widetilde{\mathbf{h}}_{\tau}^{n-1} \mathbf{W}_{v}^{n \top} \\ \mathbf{A}_{\tau, i, j}^{n} = \mathbf{q}_{\tau, i}^{n \top} \mathbf{k}_{\tau , j}^{n} + \mathbf{q}_{\tau, i}^{n \top} \mathbf{W}_{k,R}^{n} \mathbf{R}_{i-j} + u^\top \mathbf{k}_{\tau , j} + v^\top \mathbf{W}_{k,R}^n \mathbf{R}_{i-j} \\ \mathbf{a}_{\tau}^{n}=\text { Masked-Softmax }\left(\mathbf{A}_{\tau}^{n}\right) \mathbf{v}_{\tau}^{n} \\ \left.\mathbf{o}_{\tau}^{n}=\text { LayerNorm(Linear }\left(\mathbf{a}_{\tau}^{n}\right)+\mathbf{h}_{\tau}^{n-1}\right) \\ \mathbf{h}_{\tau}^{n}=\text { Positionwise-Feed-Forward }\left(\mathbf{o}_{\tau}^{n}\right) \end{array}

初始化为词嵌入序列。此外,计算的效率随序列长度呈二次方变化。下面将介绍一个对的高效计算方式,其效率随序列长度呈线性变化。

Efficient Computation of the Attention with Relative Positional Embedding

倘若以基本方法计算考虑相对位置的attention score,其中对于所有对的计算呈二次方的消耗。因此,本文提出一种线性消耗的计算方法。已知,相对距离的只能是到的整数值,M是记忆长度,L是片段长度。令,则:
\mathbf{Q}=\left[\begin{array}{c} \mathbf{R}_{M+L-1}^{\top} \\ \mathbf{R}_{M+L-2}^{\top} \\ \vdots \\ \mathbf{R}_{1}^{\top} \\ \mathbf{R}_{0}^{\top} \end{array}\right] {\mathbf{W}_{k, R}}^{\top}=\left[\begin{array}{c} {\left[\mathbf{W}_{k, R} \mathbf{R}_{M+L-1}\right]^{\top}} \\ {\left[\mathbf{W}_{k, R} \mathbf{R}_{M+L-2}\right]^{\top}} \\ \vdots \\ {\left[\mathbf{W}_{k, R} \mathbf{R}_{1}\right]^{\top}} \\ {\left[\mathbf{W}_{k, R} \mathbf{R}_{0}\right]^{\top}} \end{array}\right] \in \mathbb{R}^{(M+L) \times d}
令,则

接下来,对于attention score中的(b)项,收集所有的对,形成如下一个的矩阵:
\begin{aligned} \mathbf{B} &=\left[\begin{array}{ccccccc} q_{0}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{M} & \cdots & q_{0}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{0} & 0 & \cdots & 0 \\ q_{1}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{M+1} & \cdots & q_{1}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{1} & q_{1}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{0} & \cdots & 0 \\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots \\ q_{L-1}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{M+L-1} & \cdots & q_{L-1}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{M+L-1} & q_{L-1}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{L-1} & \cdots & q_{L-1}^{\top} \mathbf{W}_{k, R} \mathbf{R}_{0} \end{array}\right] \\ &= \left[\begin{array}{ccccccc} q_0^\top \mathbf{Q}_{L-1} & \cdots & q_0^\top \mathbf{Q}_{M+L-1} & 0 & \cdots & 0 \\ q_1^\top \mathbf{Q}_{L-2} & \cdots & q_1^\top \mathbf{Q}_{M+L-2} & q_1^\top \mathbf{Q}_{M+L-1} & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots & \ddots & \vdots \\ q_{L-1}^\top \mathbf{Q}_{0} & \cdots & q_{L-1}^\top \mathbf{Q}_{M} & q_{L-1}^\top \mathbf{Q}_{M+1} & \cdots & q_{L-1}^\top \mathbf{Q}_{M+L-1} \end{array}\right] \end{aligned}
接下来,定义一个新矩阵:
\widetilde{\mathbf{B}} = \mathbf{qQ}^\top = \left[\begin{array}{ccccccc} q_0^\top \mathbf{Q}_{0} & \cdots & q_0^\top \mathbf{Q}_{M} & q_0^\top \mathbf{Q}_{M+1} & \cdots & q_0^\top \mathbf{Q}_{M+L-1} \\ q_1^\top \mathbf{Q}_{0} & \cdots & q_1^\top \mathbf{Q}_{M} & q_1^\top \mathbf{Q}_{M+1} & \cdots & q_1^\top \mathbf{Q}_{M+L-1} \\ \vdots & \vdots & \ddots & \vdots & \ddots & \vdots \\ q_{L-1}^\top \mathbf{Q}_{0} & \cdots & q_{L-1}^\top \mathbf{Q}_{M} & q_{L-1}^\top \mathbf{Q}_{M+1} & \cdots & q_{L-1}^\top \mathbf{Q}_{M+L-1} \\ \end{array}\right]
将和比较可以发现将的第\mathbf{B}i$行。

类似的,对于attention score中的项,收集所有的对,形成如下一个的矩阵:
\mathbf{D}= \left[\begin{array}{ccccccc} v^\top \mathbf{Q}_{L-1} & \cdots & v^\top \mathbf{Q}_{M+L-1} & 0 & \cdots & 0 \\ v^\top \mathbf{Q}_{L-2} & \cdots & v^\top \mathbf{Q}_{M+L-2} & v^\top \mathbf{Q}_{M+L-1} & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots & \ddots & \vdots \\ v^\top \mathbf{Q}_{0} & \cdots & v^\top \mathbf{Q}_{M} & v^\top \mathbf{Q}_{M+1} & \cdots & v^\top \mathbf{Q}_{M+L-1} \end{array}\right]
同样,可以定义:

此时的每一行可由向左平移得来。
上述方法中,平移的消耗较少,主要的计算量在于和的矩阵乘法上,从而效率得以提升。

你可能感兴趣的:(【论文阅读笔记】Transformer-XL)