首先我们介绍下seq2seq,它首次应用在机器翻译的seq2seq,也就是enoder-decoder架构。论文见《Sequence to Sequence Learning with Neural Networks》
我们以RNN举例说明,seq2seq是将输入单词的embedding输入逐步输入encoder中,每个时刻encoder的输出取决于当前时刻的输入和上一时刻的隐状态(即上一时刻的输出),最后的隐状态作为decoder的输入,decoder之后的输出也是取决与上一时刻的隐状态和上一时刻的输出单词的embedding,最终输出的单词是decoder输出的隐状态全联接softmax之后得到最大概率的那个单词,这样最后一步步输出单词序列。这样做的缺点是当输入序列较长的时候,只靠encoder的最后状态很难捕捉前后依赖关系,因此引入注意力机制即attention。
注意力机制最早出现在CV领域,表示人看一物体其实对不同部分的注意力不一样,对关键吸引人的地方注意力会多些。论文见《Neural Machine Translation by Jointly Learning to Align and Translate》
attention的具体应用如下:对于seq2seq,为了解决上述缺点,每一时刻预测输出单词的时候不止依赖decoder的隐状态,还依赖encoder的输入历史信息,这里我们定义context vector表示它,context vector为输入encoder每个时刻隐状态的线性加权平均得到。
假设输入序列长度为 T T T,则 j j j时刻的context vector计算 c j ^ = ∑ i = 0 T α i j h ^ j \hat{c_j}=\sum_{i=0}^{T}\alpha_{ij}\hat h_j cj^=i=0∑Tαijh^j其中系数 α i j \alpha_{ij} αij为第j个输出对i时刻encoder隐状态的注意力(意思就是对当前时刻输出的重要性), 并且 ∑ i = 0 T α i j = 1 \sum_{i=0}^T{\alpha_{ij}} = 1 ∑i=0Tαij=1。其中第j个输出i时刻输入的注意力 α i j \alpha_{ij} αij的计算公式为 α i j = e x p ( e i j ) ∑ k = 0 T e x p ( e i k ) \alpha_{ij}=\frac{exp(e_{ij})}{\sum_{k=0}^Texp(e_{ik})} αij=∑k=0Texp(eik)exp(eij)
其中 e i j = s c o r e ( h ^ i , s ^ j ) e_{ij}=score(\hat h_i, \hat s_j) eij=score(h^i,s^j)(score可以是点乘dot), s ^ j \hat s_j s^j为decoder第 j j j个输出的隐状态, h ^ i \hat h_i h^i为 i i i时刻encoder输入的隐状态. 计算完后的context vector涵盖之前所有encoder的输入信息,和上一时刻的输出单词embeding进行concat作为新的decoder输入,这里注意初始context vector为全0向量。所以说这样相比不带attention的seq2seq能对输入序列的前后依赖关系更好的建模。
可以发现seq2seq的enoder-decoder架构加入attention机制后,对长序列输入的前后依赖关系能更好的建模。然而它也存在缺点,因为enoder-decoder架构均采用RNN建模,RNN的天然本质注定很难并行(当前时刻输入依赖上一时刻的输出),并且存在梯度爆炸弥漫等问题(LSTM, GRU等改进只能减少不能完全避免),CNN是通过叠加感受野的方式(需要非常深的网络才能捕捉较长的依赖),而attention有着无视距离的优势并且可以并行的优势,因此提出了抛弃RNN而纯依赖attention对输入序列的前后依赖进行建模的transformer,论文见《Attention is all you need》
回顾seq2seq中的attention,是对encoder的历史隐状态做线性加权平均,系数越高的隐状态则表示当前时刻的输出和其越相关。这个系数是从当前时刻decoder的隐状态和encoder的隐状态交互计算(比如dot product)后取softmax得到的,其实我们可以看成为以下机制
如图,给定query进行kv对查询,分别计算query和不同key的相似度,经过softmax后得到注意力系数,后对value加权平均得到attention value.即如下公式
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dkQKT)V
论文用的就是这种scaled dot product的attention机制,图片表示如下:
这里Q为 n × d k n\times d_k n×dk的矩阵,K为 m × d k m\times d_k m×dk的矩阵,V为 m × d k m\times d_k m×dk的矩阵,因此,一个attention层可以看作把 n × d k n\times d_k n×dk的输入序列编码成了 n × d v n\times d_v n×dv的序列
注意,这里scale by 1 d k \frac{1}{\sqrt{d_k}} dk1是为了平滑softmax(当 d k d_k dk过大导致dot product过大的值,softmax的梯度过小)
特别的,对于self-attention就是K,Q, V均相等的情况,自己对自己计算相互的依赖关系。
编码器为相同N=6个layer进行stack得到,每层有两个子层,分别为multi-head的self-attention和fully connect layer,这里采用了ResNet的思想建立了一个short-cut, LayerNorm(x + Sublayer(x))。
Encoder的输入,最底层为sequence的embedding,之后每层的输入为上一层的输出。
解码器同样为N=6个相同的layer进行stack得到。和encoder不同的是,这里有两个multi-head attention。最底下的masked multi-head self-attention是为了保证输出只能看到当前时刻之前的输出结果。上一层的multi-head attention也叫encoder-decoder attention,它并不是self-attention,是用底层的multi-head self-attention attention的输出作为Q,encoder的输入作为K和V,这样每次decoder的输出都携带了各个位置的输入信息,类似于传统的seq2seq的attetnion机制。
Decoder当前时刻的输入为encoder的输出和Decoder所看到的所有词的embedding,初始时刻假设看到的词一般为特殊字符,比如<\s>,图中即shifedt right
论文里的multi-head计算如下
这里head的个数为h=8. 每个head和刚才的注意力计算公式意义,第 i i i个头用 W i Q W_i^Q WiQ, W i K W_i^K WiK W i V W_i^V WiV进行映射,最后映射到纬度为 d v d_v dv,这样不同的head的Q,K,V就被映射到不同的子空间进行学习. W O W^O WO将最后纬度还原成 d m o d e l d_{model} dmodel
transformer并没有用到CNN或者RNN类似的结构,输入的embedding无法体现词的位置信息,因此对于输入的embedding,需要增加含位置信息的position encoding
position encoding公式如上图所示,pos就是位置,i是embedding向量的所在位置,由公式可以知道奇数位置采用cos,偶数位置采用sin;最终得到的position encoding向量与初始的embedding相加最为最终的输入。
对于为什么要用这个函数,论文里是这样说的
We chose this function because we hypothesized it would allow the model to easily learn to attend by relative positions, since for any fixed offset k k k, P E p o s + k PE_{pos+k} PEpos+k can be represented as a linear function of P E p o s PE_{pos} PEpos
作者更多的是凭经验得到的,因为他想让模型更好的学会相对位置关系,在这个公式中,对于embedding向量的某个固定位置 k k k的数值 P E p o s + k PE_{pos+k} PEpos+k,都可以用初始位置 P E p o s PE_{pos} PEpos线性表示。这点在数学利用三角函数的性质是可以证明的。
优点 :计算量减少,可并行计算增加,远程依赖关系之间的路径长度小(通过矩阵dot poduct即可建立依赖关系)
缺点:有些rnn可以解决的问题transformer做的不好,比如copy string,或者inference时碰到的sequence长度比training更长的处理很差;图灵不完备