Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context翻译

原文链接: https://arxiv.org/pdf/1901.02860.pdf

github:https://github.com/kimiyoung/transformer-xl

摘要

Transformers具有学习长期依赖的潜力,但在语言模型的设置中受到固定长度上下文的限制。我们提出了一种新型的神经网络结构Transformer-XL,它可以在不中断时间一致性的情况下学习到超出固定长度的依赖性。它由段级循环机制和新的位置编码方案组成。我们的方法不仅可以捕获长距离依赖性,还可以解决上下文碎片问题。因此,Transformer-XL学习的依赖性比RNN长80%,比vanilla Transformers长450%,在短序列和长序列上都能获得更好的性能,并且在评估过程中比vanilla Transformers快1800倍。值得注意的是,我们将bpc / perplexity的最新结果在enwiki8上改进为0.99,在text8上改进为1.08,在WikiText-103上为18.3,在十亿个字上为21.8,在Penn Treebank上为54.5(没有微调)。当仅在WikiText-103上进行训练时,Transformer-XL成功地生成具有数千个字符的合理连贯的文本文章。在Github上,我们都提供了Tensorflow和PyTorch的代码,预训练模型和超参数。

1.介绍

构建长距离依赖性是语言模型需要解决的重要问题之一,模型在如无监督预训练方面具有成功的应用 。然而,为神经网络配备在序列数据中建模长距离依赖性的能力一直是一个挑战。循环神经网络(RNNs),特别是长短期记忆(LSTM)网络,已成为语言模型的标准解决方案,并在多个基准测试中获得了较好的结果。尽管其适应范围很广,但由于梯度消失和爆炸,RNNs难以进行优化,LSTM中门控机制的引入和梯度裁剪技术可能不足以完全解决这个问题。根据经验,以前的工作发现LSTM语言模型平均使用200个上下文单词,表明有进一步改进的空间。
  另一方面,在注意力机制中的长距离单词对之间的直接连接可以简化网络优化过程,并且能够学习长距离依赖性。最近,Al-Rfou等人设计了一套辅助损失(auxiliary losses),用于训练深度Transformer网络而进行字符级语言建模,其性能大幅优于LSTM。尽管取得了成功,但在Al-Rfou等人的LM训练中,在几百个字符的分离的固定长度段上执行,而没有任何信息流过段。由于固定的上下文长度,模型无法捕获超出预定义上下文长度的任何长距离依赖性。另外,通过选择连续的字符块而不考虑句子或任何其他语义边界来创建固定长度的片段。因此,该模型缺乏良好预测前几个字符所需的必要的上下文信息,导致低效优化和较差的性能。我们将此问题称为上下文碎片(context fragmentation)
  为了解决固定长度上下文的上述限制,我们提出了一种名为Transformer-XL的新架构(意思是超长)。我们将循环的概念引入我们深层的self-attention网络。特别是,我们不是从头开始计算每个新段的隐藏状态,而是重用先前段中获得的隐藏状态。重用的隐藏状态用作当前段的存储器,其在段之间建立循环连接。因此,可以建立非常长距离的依赖性,因为信息可以通过循环连接传播。同时,从前一段传递信息也可以解决上下文碎片的问题。更重要的是,我们展示了使用相对位置编码而不是绝对编码的必要性,以便能够在不引起时间混淆的情况下重用状态。因此,作为额外的技术贡献,我们引入了一种简单但更有效的相对位置编码公式,其推广的注意力长度长于训练期间观察到的长度。
  Transformer-XL在五个数据集上获得了很好的结果,从词级到特征级语言建模不等。 Transformer-XL还能够生成具有数千个字符的相对连贯的长文本文章(参见附录E),仅使用100M字符进行训练。
  我们的主要技术贡献包括在原始的self-attentive模型中引入循环概念并推导出一种新的位置编码方案。这两种技术形成了一套完整的解决方案,因为它们中的任何一种都不能解决固定长度的上下文问题。Transformer-XL是第一个self-attention模型,它在字符级和单词级语言建模方面比RNN实现了更好的结果。

2.相关工作

过去几年中,在语言模型领域已经见证了许多重大进展,包括但不限于设计新的架构以更好地编码环境,改进正则化和优化算法,加速Softmax计算,并丰富输出分布类型。
  为了捕获语言模型中的远距离上下文,一系列工作直接将更广泛的上下文的表示作为附加输入提供给网络。现有的工作范围从手工定义上下文表示的工作到依赖从数据中学习的文档级主题的其他工作。
  更广泛地说,在通用序列模型(序列标注,序列生成)中,如何捕获长距离依赖性一直是一个长期存在的研究问题。从这个角度来看,由于LSTM的普遍适应,已经花费了许多努力来缓解梯度消失的问题,包括更好的初始化,额外的丢失信号,增强的存储器结构 和其他修改RNN内部结构以简化优化的方法。与他们不同,我们的工作基于Transformer架构,并表明语言模型作为一项实际任务,可以从学习长距离依赖的能力中获益。

3.模型

给定一个字符语料 X = ( x 1 , . . . , x T ) X =(x_1,...,x_T) X=(x1,...,xT),语言模型的任务是估计联合概率 P ( X ) P(X) P(X),其通常被自动回归因式分解为 P ( X ) = ∏ t P ( x t ∣ x < t ) P(X)= \prod_t P(x_t| x_{<t}) P(X)=tP(xtx<t)。通过因式分解,问题减少到了去估计每个条件因子。在这项工作中,我们坚持使用标准神经网络方法来建模条件概率。具体地,使用可训练的神经网络对上下文 x < t x_{<t} x<t编码为固定大小的隐藏状态,其与单词词向量相乘以获得logits输出。然后将logits输入到Softmax函数中,从而在下一个字符上产生分类概率分布。

3.1 Vanilla Transformer语言模型

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context翻译_第1张图片
  为了将Transformer或self-attention应用于语言模型,核心问题是如何训练Transformer以将任意长的上下文有效地编码为固定大小的表示。假设具有无限的存储和计算能力,一个简单的解决方案是使用无条件Transformer解码器处理整个上下文序列,类似于前馈神经网络。 然而,这在实际的有限资源中通常是不可行的。
  一个可行但粗略的近似是将整个语料分成可管理大小的较短片段,并且仅在每个片段内训练模型,忽略来自先前片段的所有上下文信息。这是Al-Rfou等人采用的想法。我们将其称为Vanilla模型,并在图1(a)中将其可视化。在这种训练范式下,信息在前进或后退中都不会流过各个部分。使用固定长度上下文有两个关键限制:首先,最大可能的依赖长度是由段长度限制的,在字符级语言建模上是几百(Al-Rfou等,2018)。因此,尽管与RNN相比,self-attention机制受梯度消失问题的影响较小,但是vanilla模型不能充分利用这种优势。其次,尽管可以使用padding来作为句子或其他语义边界,但在实践中,由于效率的提高,将长文本简单地分成固定长度的片段已经成为标准做法。但是,简单地将序列分块为固定长度的段将导致上下文碎片问题,如第1节所述。
  在评估期间的每个步骤中,Vanilla模型也使用与训练中相同长度的片段,但仅在最后位置进行一次预测。然后,在下一步,该段仅向右移动一个位置,并且必须从头开始处理新段。如图1(b)所示,该过程确保每个预测字符能够利用训练期间使用的最长可能上下文,并且还减轻训练中遇到的上下文碎片问题。 但是,这种评估过程非常昂贵。我们将证明我们提出的架构能够显提高评估速度

3.2 基于分段循环的状态重用

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context翻译_第2张图片
  为了解决使用固定长度上下文的限制,我们建议在Transformer体系结构中引入循环机制。在训练期间,当模型处理下一个新段时,为前一段计算的隐藏状态序列被固定并且缓存以作为扩展上下文重用,如图2(a)所示。尽管梯度仍然保留在一个段内,但这个额外的输入允许网络利用历史段中的信息,从而能够建模长距离依赖性并避免上下文碎片。形式上,令长度为 L L L的两个连续段分别为 s τ = [ x τ , 1 , ⋅ ⋅ ⋅ , x τ , L ] s_{τ}=[x_{τ,1},···,x_{τ,L}] sτ=[xτ,1,,xτ,L] s τ + 1 = [ x τ + 1 , 1 , ⋅ ⋅ ⋅ , x τ + 1 , L ] s_{τ+1}=[x_{τ+1,1},···,x_{τ+1,L}] sτ+1=[xτ+1,1,,xτ+1,L]。将由第 τ τ τ s τ s_τ sτ产生的第 n n n层隐藏状态序列表示为 h τ n ∈ R L × d h^n_τ∈R^{L×d} hτnRL×d,其中 d d d是隐藏态维数。然后,用于段 s τ + 1 s_{τ+1} sτ+1的第 n n n层的隐藏状态表示如下:
h ~ τ + 1 n − 1 = [ S G ( h τ n − 1 ) ◦ h τ + 1 n − 1 ] , \widetilde{h}^{n−1}_{τ+1}=[SG(h^{n−1}_{τ} ) ◦ h^{n−1}_{τ+1}], h τ+1n1=[SG(hτn1)hτ+1n1],
q τ + 1 n , k τ + 1 n , v τ + 1 n = h τ + 1 n − 1 W q ⊤ , h ~ τ + 1 n − 1 W k ⊤ , h ~ τ + 1 n − 1 W v ⊤ , q^n_{τ+1}, k^n_{τ+1}, v^n_{τ+1}=h^{n−1}_{τ+1}W^{\top}_q , \widetilde{h}^{n−1}_{τ+1}W^{\top}_k , \widetilde{h}^{n−1}_{τ+1}W^{\top}_v , qτ+1n,kτ+1n,vτ+1n=hτ+1n1Wq,h τ+1n1Wk,h τ+1n1Wv,
h τ + 1 n = T r a n s f o r m e r L a y e r ( q τ + 1 n , k τ + 1 n , v τ + 1 n ) . h^{n}_{τ+1}=TransformerLayer(q^{n}_{τ+1}, k^{n}_{τ+1}, v^n_{τ+1}). hτ+1n=TransformerLayer(qτ+1n,kτ+1n,vτ+1n).
其中,函数 S G ( ⋅ ) SG(·) SG()代表停止梯度(stop-gradient)(即SG内的值不进行梯度反向传播),符号 [ h u ◦ h v ] [h_u◦h_v] [huhv]表示沿序列长度方向将两个隐藏序列连接起来, W . W. W.表示模型参数。与标准Transformer相比,关键的区别在于,键 k τ + 1 n k^n_{τ+1} kτ+1n和值 v τ + 1 n v^n_{τ+1} vτ+1n以扩展的上下文 h ~ τ + 1 n − 1 \widetilde{h}^{n-1}_{τ+1} h τ+1n1作为输入,因此 h τ n − 1 h^{n-1}_τ hτn1从前一个段缓存。我们通过图2(a)中的绿色路径强调这种特殊设计。
  将此循环机制应用于语料中的每两个连续段,它实质上在隐藏状态中创建了一个段级循环。因此,所使用的有效上下文能超够出两个部分。然而,注意到 h τ + 1 n h^{n}_{τ+1} hτ+1n h τ n − 1 h^{n-1}_τ hτn1之间的循环依赖性每段向下移动一层,这与传统RNN-LM中的同层循环不同。 因此,最大可能的依赖长度呈线性增长,并且与层数以及段长度有关,即 O ( N × L ) O(N×L) O(N×L),如图2(b)中的阴影区域所示。这类似于截断的BPTT(这是一种用于训练RNN-LM的技术)。但是,与截断的BPTT不同,我们的方法缓存一系列隐藏状态而不是最后一个隐藏状态,并且应该与3.3节中描述的相对位置编码技术一起应用。
  除了实现更长的上下文关联和解决碎片化之外,循环机制带来的另一个好处是显着加快了评估速度。具体地,在评估期间,可以重复使用来自先前段的表示,而不是像在vanilla模型的情况下从头开始计算。在我们的enwiki8实验中,Transformer-XL在评估过程中比vanilla模型快1800倍(参见第4节)。
  最后,请注意,循环机制不需要仅限于前一个段。理论上,我们可以在GPU内存允许的条件下缓存尽可能多的先前段,并在处理当前段时将所有段重用为额外的上下文。因此,我们将可以缓存隐藏状态数目预定义为长度 M M M,跨越(可能)多个段,并将它们称为存储器 m τ n ∈ R M × d m^n_τ∈R^{M×d} mτnRM×d,这是由于与记忆增强神经网络的明确联系(Graves等。 ,2014; Weston等,2014)。在我们的实验中,我们将M设置为训练期间的段长度,并在评估期间将其增加多倍。

3.3 相对位置编码

虽然我们发现前一小节中提出的想法非常吸引人,但是为了重用隐藏状态,我们还没有解决一个关键的技术挑战。也就是说,当我们重用隐藏状态时,我们如何保持位置信息的一致性? 回想一下,在标准Transformer中,输入序列的位置信息是由一组位置编码提供,表示为 U ∈ R L m a x × d U∈R^{L_{max}×d} URLmax×d,其中第 i i i U i U_i Ui对应于段内的第 i i i个绝对位置。 L m a x L_{max} Lmax规定了要建模的最大可能长度。然后,Transformer的实际输入是单词词向量和位置编码的叠加。如果我们简单地将这种位置编码调整到我们的循环机制,隐藏状态序列将按如下公式计算:
h τ + 1 = f ( h τ , E s τ + 1 + U 1 : L ) h τ = f ( h τ − 1 , E s τ + U 1 : L ) , h_{τ+1}=f(h_τ, E_{s_{τ+1}}+U_{1:L})\\ h_τ = f(h_{τ−1}, E_{s_τ} +U_{1:L}), hτ+1=f(hτ,Esτ+1+U1:L)hτ=f(hτ1,Esτ+U1:L),
其中 E s τ ∈ R L × d E_{s_τ}∈R^{L×d} EsτRL×d s τ s_τ sτ的词向量表示序列,f表示变换函数。注意, E s τ E_{s_τ} Esτ E s τ + 1 E_{s_{τ+1}} Esτ+1都与相同的位置编码 U 1 : L U_{1:L} U1:L相关联。因此,对于任何 j = 1 , . . . , L j=1,...,L j=1,...,L,模型没有信息来区分 x τ , j x_{τ,j} xτ,j x τ + 1 , j x_{τ+1,j} xτ+1,j之间的位置差异,这导致模型的性能损失。
  为了避免这种情况发生,基本思想是仅对隐藏状态中的相对位置信息进行编码。从概念上讲,位置编码为模型提供了关于应如何收集信息的时间线索或“偏差”,即在哪里加入。出于同样的目的,不是将偏差静态地结合到初始词向量中,而是可以将相同的信息注入到每层的注意力分数中。更重要的是,以相对方式定义时间偏差更直观和通用。例如,当一个问题向量 q τ , i q_{τ,i} qτ,i乘以键向量 k τ , ≤ i k_{τ,≤i} kτ,i时,不需要知道每个键向量的绝对位置以识别该段的时间顺序。相反,知道每个键向量 k τ , j k_{τ,j} kτ,j与其自身 q τ , i q_{τ,i} qτ,i,即 i − j i-j ij之间的相对距离就足够了。实际上,可以创建一组相对位置编码 R ∈ R L m a x × d R∈R^{L_{max}×d} RRLmax×d,其中第 i i i R i R_i Ri表示两个位置之间的 i i i的相对距离。通过将相对距离动态地注入注意力分数,问题向量可以容易地区分 x τ , j x_{τ,j} xτ,j x τ + 1 , j x_{τ+1,j} xτ+1,j与其不同距离的表示,使得状态重用机制可行。同时,我们不会丢失任何时间信息,因为绝对位置可以从相对距离递归地恢复。
  以前,在机器翻译和音乐生成的背景下探索了相对位置编码的概念。在这里,我们提供了一种不同的推导,得出了一种新形式的相对位置编码,它不仅与其绝对位置具有一对一的对应关系,而且在经验上也有更好的泛化(见第4节)。
  (1)绝对位置编码
  首先,在标准Transformer中,同一段内的问题 q i q_i qi和键向量 k j k_j kj之间的注意力得分可以分解为:
A i , j a b s = E x i ⊤ W q ⊤ W k E x j ⎵ ( a ) + E x i ⊤ W q ⊤ W k U j ⎵ ( b ) + U i ⊤ W q ⊤ W k E x j ⎵ ( c ) + U i ⊤ W q ⊤ W k U j ⎵ ( d ) . A^{abs}_{i,j}=\underbrace{E^{\top}_{x_i}W^{\top}_qW_kE_{x_j}}_{(a)}+\underbrace{E^{\top}_{x_i}W^{\top}_qW_kU_{j}}_{(b)}\\ +\underbrace{U^{\top}_{i}W^{\top}_qW_kE_{x_j}}_{(c)}+\underbrace{U^{\top}_{i}W^{\top}_qW_kU_{j}}_{(d)}. Ai,jabs=(a) ExiWqWkExj+(b) ExiWqWkUj+(c) UiWqWkExj+(d) UiWqWkUj.
  (2)相对位置编码
  根据仅依赖相对位置信息的想法,我们建议将上式中的4项重新参数化如下:
A i , j r e l = E x i ⊤ W q ⊤ W k E x j ⎵ ( a ) + E x i ⊤ W q ⊤ W k R i − j ⎵ ( b ) + u ⊤ W q ⊤ W k E x j ⎵ ( c ) + u ⊤ W q ⊤ W k R i − j ⎵ ( d ) . A^{rel}_{i,j}=\underbrace{E^{\top}_{x_i}W^{\top}_qW_kE_{x_j}}_{(a)}+\underbrace{E^{\top}_{x_i}W^{\top}_qW_kR_{i-j}}_{(b)}\\ +\underbrace{u^{\top}W^{\top}_qW_kE_{x_j}}_{(c)}+\underbrace{u^{\top}W^{\top}_qW_kR_{i-j}}_{(d)}. Ai,jrel=(a) ExiWqWkExj+(b) ExiWqWkRij+(c) uWqWkExj+(d) uWqWkRij.

你可能感兴趣的:(语言模型)