Transformer解释

Transformer解释_第1张图片

和seq2seq模型一样,transformer也是encoder和decoder组成 。

Transformer 的时间复杂度为 O(LN^2H),L为模型层数,H是注意力个数,N表示输入序列长度。

Attention 的时间复杂度主要受相似度计算及加权和的计算决定,n * d d/h * n  -> O(nd),n:输入序列长度,d:向量维度,h:注意力个数。

Encoder由N=6个相同的layer组成,layer指的就是上图左侧的单元,最左边有个“Nx”,此处x6个。每个Layer由两个sub-layer组成,分别是multi-head self-attention mechanism和fully connected feed-forward network。其中每个sub-layer都加了residual connection和normalisation,因此可以将sub-layer的输出表示为:

Attention的形式:

多头attention:通过h个不同的线性变换对Q,K,V进行投影,最后将不同的attention结果拼接。

Transformer解释_第2张图片

多头Attention可对不同部分进行注意,获得更多的表示能力,类似CNN网络中多个滤波器的效果。

Transformer 中使用 Layer Normalization(LN) 而不是 Batch Normalization(BN),因为Transformer是多头并行,而BN是基于Mini-Batch的,需要等待这个mini-batch数据输入完成才能继续训练,对处理不同长度文本时,计算均值和方差可能会出现偏差,在测试集上效果不好,特别是测试集样本长度分布不同于训练集。

LN不依赖外部数据,只依赖于当前层的输出,更好的适应不同长度输入数据,不受其他头影响。

Position-wise feed-forward networks:提供非线性变换。Attention输出的维度是[bsz * seq_len, num_heads * head_size],第二个sub-layer是个全连接层,position-wise因为过线性层时每个位置 i 的变换参数是一样的。

Decoder和Encoder的结构相似,多了一个attention的sub-layer,decoder的输入输出和解码过程:

  • 输出:对应 i 位置的输出词的概率分布
  • 输入:encoder 的输出及对应 i - 1 位置decoder的输出。中间的attention不是self-attention,其K,V来自encoder,Q来自上一位置decoder的输出。
  • 解码:训练和预测是不一样的。在训练时,解码是一次全部decode出来,用上一步的ground truth来预测(mask矩阵也会改动,让解码时看不到未来的token);而预测时需要一个个预测。

新加的attention多加了一个mask,因训练时的output都是ground truth,可确保预测第i个位置时不会接触到未来的信息;加了mask的attention原理如图(另附multi-head attention):

Transformer解释_第3张图片

 两种Positional Encoding的方法:

  1. 用不同频率的sine和cosine函数直接计算
  2. 学习出一份positional embedding

实验发现两者的结果一样,所以最后选择了第一种方法,公式如下:

Transformer解释_第4张图片

上述位置计算的优势:

Transformer解释_第5张图片 如果是学习到的positional embedding,会像词向量一样受限于词典大小;也就是只能学习到 “位置2对应的向量是(1,1,1,2)” 这样的表示。而用三角公式明显不受序列长度的限制,也就是可以对比所遇到序列的更长的序列进行表示。

Transformer的一些缺点:

  1. 实践上:有些rnn轻易可以解决的问题transformer没做到,比如复制string,或者推理时碰到的sequence长度比训练时更长(因为碰到了没见过的position embedding)
  2. 理论上:transformers非computationally universal(图灵完备)

Transformer是第一个用纯attention搭建的模型,不仅计算速度更快,在翻译任务上获得了更好的结果,也为后续的BERT模型做了铺垫。

参考:

【NLP】Transformer模型原理详解 - 知乎 (zhihu.com)

The Annotated Transformer

你想要的Transformer这里都有 - 知乎

你可能感兴趣的:(nlp,自然语言处理)