transformer者,传思法模也。 ——木盏
如果非要找一个模型来作为近三年来AI算法进展的突出代表,我认为transformer定会高票当选。
本文作为算法解析文章,倡导思想为主,公式为辅,希望有助于大家理解transformer。行文逻辑为总-分-总结构。本文所有未标来源的图片均为本人所画,引用时请附上本文链接。
本文首发于本人知乎:https://zhuanlan.zhihu.com/p/361740395 (排版不如CSDN,所以不用移步啦~)
transformer最早是用于机器翻译,所以我以汉译英来举例。
在transformer中,所输入句子中每个独立的字或者单词被当作单独个体(下文统一称为“节点”),每个节点又与句子整体是密不可分的。对于翻译句子而言,在处理某一节点信息时,需要了解句中其他节点对当前节点的“影响力”。比如,“我喜欢打篮球。”这个句子,翻译“我”的时候需要吸取其他节点的“意思”,这样才能给出准确地翻译。同样,处理其他节点时也会进行这个操作。这便是transformer中self-attention的思想,原文[1]描述如下:
Self-attention, sometimes called intra-attention is an attention mechanism relating different positions of a single sequence in order to compute a representation of the sequence.
如此一来,是不是可以跟卷积做一个类比?拿3x3卷积举例,处理中间pixel值会收到周围8个pixel的信息聚合。而transformer会聚合整个句子中所有词节点的信息,通常也被认为其“全局性”是好于CNN的。这也是transformer能在CV界兴风作浪的理论支持。
其实直接将transfomer这种信息聚合类比于卷积并不是最恰当的,更恰当的类比是GAT,即Bengio在2017年(稍晚于transformer)发表的Graph Attention Network. [4]。很巧合的是,他们都强调attention,并且都有Multi-head Attention的trick。有兴趣的可以看我以前写的博文《图注意力网络》。做一个学术总结:transformer是一种特殊的图注意力网络(将句子当作以每个单词为图节点、并且两两有连接的图)。国外有专门比较transformer和GNN思想的文章,有兴趣可戳Transformers are Graph Neural Networks。
不得不承认的是,transformer与GAT相比,把思想交集部分除外,更有很多超越之处:1. transformer并不是简单的加权聚合,而是构造了Q/K/V的概念,增强了权重的意义;2. 针对序列数据设计的encoder-decoder结构(图3所示),天然支持不等长输出,还能选择加入recurrence(如decoder)。
另一个常与transformer比较的工作是FAIR团队发表在CVPR2018的Non-local[5],时间稍晚于transformer,且引用了transformer论文。图4所示的non-local block跟transformer的encoder的attention部分(图5)相似度极高。只看attention部分的话,细心的你肯定会发现它们几乎只是画图风格的不同而已。从某种角度看,non-local应该属于早期将transformer思想应用于视觉的work之一。
本文直接分析最原始的transformer结构(如图6所示),图6中有红色标记的地方将会着重解答。
第一节讲述了transformer做了一件什么事情(聚合全局信息),这一节我们来看看具体实现。先看attention机制,把图6中“Multi-Head Attention”打开看:
先了解三兄弟:Key、Query、Value,对应图7中的K、Q、V。其中文翻译为:键、查询、值。
KQV都是从词嵌入向量通过矩阵运算得到,可以看作KQV是单词向量从不同方向的编码。
输入句子到transformer需要先对每个单字进行编码,用向量形式表达出这个单词,向量维度的典型值为512。而编码后的KQV三个向量通常都远小于嵌入向量的维度,典型值为64。如图8所示:
简单描述一下:一个512d的输入向量,直接点乘三个不同的转换矩阵得到三个64d的输出向量。需要强调的是,这里的都是通过训练得到的,而且对于同一层所有节点的KQV计算,其转换矩阵的参数都是共享的。
对于输入句子中的每一个节点都会计算出三个小向量KQV。两两节点之间的attention就是通过这三个向量计算得到的,先看attention公式:
公式(1)中的为键向量K的维度,典型值为64。这里的softmax为的是把权值整理到(0,1)的区间内,并且所有权值加到一起等于1,想了解更多可戳《softmax》。光看公式(1)很容易被误导:QKV都是从一个向量上提取出来的,现在将它们以某种形式合并,会不会多此一举。其实公式(1)计算的QKV未必是来自同一个输入向量。这个时候,我们来看看Jalammar[3]画的一张图:
由图9可知,QK相乘得到一个数值,在对这个数值做一些标准化操作(除以,再过一层softmax),就可以得到“注意程度”,这个“注意程度”乘以V就可以得到Attention。图9中的各个数值仅供示例,应该由数值准确度待考。
为了更好地解释attention计算,我把transformer的encoder重新画了一下:
如图10所示,attention计算之后的输出z表示加权后的结果。图9表示的是的计算结果。咱们公式化描述一下图9:
公式(2)则应该是attention聚合其他节点信息的表达,其中 n 表示节点个数(即句子长度)。公式(2)并非原文给出的公式,而是我根据原文[1]中关于attention描述写出来的,原文描述如下:
An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.
原文提到的a query and a set of key-value pairs前半句反应出拿当前节点的Q去和所有节点(包括当前节点)的KV对做计算。output is computed as a weighted sum of the values后半句表示KQ计算可以得到权重,然后对V进行加权求和即可得到当前节点的输出z。
通过公式(2),我们以此可以算出(假设句子长度为n)。
继续看图10,沿着箭头网上看,对应位置的x和z求和。这里的z维度是64,而x维度是512,其实这里的z还要乘以一个变换矩阵到512维(稍后的multi-head attention将会讲这个)。然后送入到Layer Norm层(简称LN)。LN是做层归一化,相对于广泛认知的BN只是操作维度不一样而已,其做法主要是让输出分布在原点附近,增加稀疏性,降低过拟合的风险。想进一步了解LN/BN/GN/IN等的可以看我之前的文章《Normaliztion》。
再沿着图10的箭头网上看,可以看到每个位置都接了一个FeedForward Network(FFN)。直接看原文公式:
这不就是个激活函数为Relu的MLP嘛。不过需要注意的是,原文描述FFN时用了position-wise。针对position-wise,原文描述如下:
In addition to attention sub-layers, each of the layers in our encoder and decoder contains a fully connected feed-forward network, which is applied to each position separately and identically.
separately和identically用得就很灵性,这里的FFN对每个节点分别做,不会concat到一起再做,并且对每个节点的FFN参数都是一样的。
这样,transformer的encoder部分就差不多了。对了,还有multi-head attention。这个多头注意力,跟卷积弄出多个channel一个意思。直接把multi-head attention的流程梳理一遍:
图11表示的是,输入4个单词的句子经过3个head的attention示意图。每个单词都是一个512d的向量,那么4个单词就是4x512的矩阵。分别经过三组attention,这里每组attention都会计算公式(2),并且它们之间互不干涉。可以得到三组输出,然后concat它们再乘以一个转换矩阵Wo, 就可以得到与输入x同维度的矩阵。这也是图10中x可以直接与z相加的原因。
这里可以列一个公式来表述整个encoder计算:
所以,公式(6)就可以表达transformer的encoder计算了。文中[1]直接采用了6个这样的encoder堆叠到一起。
输入input embedding的时候,会直接加上一个positional encoding。这个很好理解:因为attention机制并不是position-aware的,任何单词节点在attention里都是同等对待(都是聚合全局的信息)。那么对翻译句子的任务而言,打乱了的句子跟有顺序的句子肯定是不一样的。所以,transformer就加入了位置编码,相当于注入了位置信息。
位置编码也就是通过一个特定的pattern计算出来的,不同的位置对应不同的位置编码向量。可以在[3]里看jalammar给的代码,也可以在[1]中直接看公式。本文不做进一步解释。
这里remark一下:transformer这种注入位置信息的方法略微有些朴素,如果这里得到改进说不定还能提升。
encoder被公式(6)基本上撂清了。咱们来看看薛薇复杂的decoder,直接看jalammar的图:
decoder比encoder多了一个Encoder-Decoder Attention(EDA)而已。这个EDA无非就是KV采用encoder输出,而Q采用上一decoder输出的Attention而已。在原文里描述如下:
In "encoder-decoder attention" layers, the queries come from the previous decoder layer, and the memory keys and values come from the output of the encoder. This allows every position in the decoder to attend over all positions in the input sequence.
这么一来,任何decoder中任何一个节点位置都有全局信息。
问题1:第一个decoder的Q从哪里来?
答:这里就是decoder加入recurrence的地方了。把上一时刻及之前的模型输出放回decoder当输入(类似于RNN的操作)。
问题2:为什么decoder中的第一个attention需要加mask?
答:1. 训练过程中mask掉当前位置以后的所有节点,只用当前位置之前的信息当作输入,这样才不会”泄露答案“。2. 推理过程也可以mask掉后面的位置编码(第2点是我猜想的,待考)。
问题3:为什么图6的decoder输入有个shifted right?
答:因为当第一个输出单词没得到之前,输入到decoder的只能是空。所以,在训练的时候需要把GT往右边移一个节点。这个shifted right可以理解为,在时序上落后一个节点。
问题2和问题3在原文可以找到答案:
This masking, combined with fact that the output embeddings are offset by one position, ensures that the predictions for position i an depend only on the known outputs at positions less than i.
问题4:Q,K,V到底是什么?又如何发挥作用?
答:将512维的输入向量从三个不同角度进行信息抽取得到3个64维的向量QKV。Q是查询信息,K是键信息,而V是内容信息。节点i的查询Qi,与节点j的键Kj的点乘可以反应节点i对节点k的影响程度,经过调整以后可以成为节点j聚合节点i信息的权重,而这个需要聚合的信息就是V。为什么可以这样?因为转换矩阵就是这么引导训练出来的。
关于transformer的推理过程演示,Jalammar这张动图就很灵性:
transformer可以通过输出一个结束标识符比如
链接1:Attention Is All You Need
链接2:Transformers are Graph Neural Networks
链接3:The Illustrated Transformer
链接4:Graph Attention Networks
链接5:Non-local Neural Networks
链接6:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale