Attention的应用感觉非常广泛,因此为了提升一下自己对Attention的理解就读了 A t t e n t i o n I s A l l Y o u N e e d Attention\ Is\ All\ You\ Need Attention Is All You Need [1]这篇文章,同样受到了很大的启发。虽然我感觉文章讲得还是挺乱的,我就大致按照我的理解,按照我理解一个模型的思路去讲讲。
图 一 T r a n s f o r m e r 的 结 构 图一\ Transformer的结构 图一 Transformer的结构
其实我感觉可能因为我没get作者的逻辑,所以不太能理清作者讲得顺序,我还是更喜欢从输入到输出,一步接着一步来讲,前一步和下一步进行联系,这样能更直观的理解,因此我也会这样讲。
这个就是一个全连接层,将n维的输入转换为 d m o d e l d_{model} dmodel维的,一般 d m o d e l d_model dmodel会选为512。值得一提的是,这里的全连接矩阵与Output Embedding和Linear(后文会提到的 p r e − s o f t m a x pre-softmax pre−softmax)会进行参数共享,即是同一个矩阵。
由于Transformer几乎没有位置信息,因此需要给输入的表征里蕴含位置信息,因此作者使用了一种Positional Encoding的方式,即
对 于 偶 数 位 + = sin ( p o s / 1000 0 2 i / d m o d e l ) 对 于 奇 数 位 + = cos ( p o s / 1000 0 2 i / d m o d e l ) 对于偶数位\ +=\sin(pos/10000^{2i/d_{model}})\\ 对于奇数位\ +=\cos(pos/10000^{2i/d_{model}}) 对于偶数位 +=sin(pos/100002i/dmodel)对于奇数位 +=cos(pos/100002i/dmodel)
然后取得了很好的效果,感觉非常神奇,可能attention自己就能学到这种时序依赖关系?
这里将2.2的结果输入到一个Encoder层,再将Encoder层的结果再输入到Encoder层,如此重复6次,再输入到后面,而Encoder层主要又由下面四个小部分一步一步构成。
这个是我觉得是Transformer最巧妙的地方
首先是提出了一种称为Scaled Dot-Product Attention的注意力机制。
图 2 S c a l e d D o t − P r o d u c t A t t e n t i o n 图2\ Scaled\ Dot-Product\ Attention 图2 Scaled Dot−Product Attention
注意力机制按我的理解就是为了聚合一个sequence里的重要信息,本文改良的注意力机制的式子为
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dkQKT)V
其中,Q被称为query,K被称作key,V被称作value,均为 b a t c h s i z e ∗ d m o d e l batchsize * d_{model} batchsize∗dmodel。我觉得形象上的理解,就是有一个询问query,要去找到value中对query重要的部分,那怎么做呢,先让query和某个能形容它重要性的key相乘,得到之后再去和V相乘即可,K一般就是V。在encoder中Q=K=V,都是2.2的以batchsize堆叠的输出。decoder中的会再解释。
文中指出,除以 d k \sqrt{d_k} dk是为了防止梯度爆炸。
图 3 多 头 注 意 力 机 制 图3\ 多头注意力机制 图3 多头注意力机制
Q、K、V还是上面那个Q、K、V,但是呢,考虑到一次的Attention可能没有足够的能力去学到所有的信息,那就学 h h h个矩阵 W i Q 、 W i k 、 W i v W_i^Q、W_i^k、W_i^v WiQ、Wik、Wiv,将Q、K、V映射到不同的维度上去,再拼到一起,作为注意力机制最后的结果,即
M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , ⋯ , h e a d h ) W O h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) MultiHead(Q,K,V) = Concat(head_1,\cdots, head_h)W^O\\ head_i = Attention(QW_i^Q,KW_i^K,VW_i^V) MultiHead(Q,K,V)=Concat(head1,⋯,headh)WOheadi=Attention(QWiQ,KWiK,VWiV)
其中 W i Q W_i^Q WiQ是一个 d m o d e l × d k d_{model}\times d_k dmodel×dk的矩阵, W i K W_i^K WiK是一个 d m o d e l × d k d_{model}\times d_k dmodel×dk的矩阵, W i V W_i^V WiV是一个 d m o d e l × d v d_{model}\times d_v dmodel×dv的矩阵, W O W_O WO是一个 h d v × d m o d e l hd_v\times d_{model} hdv×dmodel的矩阵,用来再归一成 d m o d e l d_{model} dmodel,一般h=8, d k = d v = d m o d e l / h = 64 d_k=d_v=d_{model}/h=64 dk=dv=dmodel/h=64
这里主要是使用了一个残差链接的方法,即将前一部分的结果和前前一部分的结果加起来,这里是multi-head-attention的结果和encoder层的输入加起来,其他部分也可以通过图1来看。加完了之后再将结果通过一个LayerNorm层,这里文章直接引用了一篇文章的方法,应该就是一个LayerNorm的方法,即本层是
o u t = L a y e r N o r m ( p r e ( o u t ) + p r e p r e ( o u t ) ) out = LayerNorm(pre(out)+pre_pre(out)) out=LayerNorm(pre(out)+prepre(out))
这里主要是两层全连接,只不过过完第一层全连接之后加了个RELU,即
o u t = m a x ( 0 , x W 1 + b 1 ) W 2 + b 2 out = max(0,xW_1+b_1)W_2+b_2 out=max(0,xW1+b1)W2+b2
同2.3.2
这个和2.1和2.2基本相同,只是把目标串进行了Embedding,而且由于预测的往往是下一个,所以这个output要输入的是下一个时间的结果,Embedding之后和2.2同样做Positional Encoding即可。又值得一提的是,本文有两处left和right和直观来看都是相反的,但代码和直观是一样的。
Decoder和Encoder结构类似,只是增加了一个Masked的Multi-Head Attention,也是整个Decoder层被叠了6遍,一个Decoder层内部总共分为6部分。
这里基本和2.3.1的Multi-Head Attention一致,Q和K和V都是堆叠成的output。不过注意到Q和K矩阵相乘的时候,它的几何意义是Q的每一个时间点,去和K的每一个时间点做点积,这样会和一些未来时刻的信息做交互,这样显然是不行的,因此我们在Q和K矩阵相乘之后,要去Mask掉与未来做交互的位置,如图2所示,然后再得到结果。
同2.3.2
和2.3.1类似,不过重要的是K矩阵和V矩阵都是Encoder层最后的输出,只有Q矩阵是2.5.2的输出
同2.3.2
同2.3.3
同2.3.2
目前模型通过2.5之后,得到的是一个 B a t c h s i z e ∗ d m o d e l Batchsize * d_{model} Batchsize∗dmodel的矩阵,然后我们最后肯定想知道,预测出每一个词的概率,这时呢,我们再利用Input Embedding学到的 N ∗ d m o d e l N*d_{model} N∗dmodel的矩阵,对Batchsize里的每一个,和Input Embedding用到的矩阵的每一项做点积,得到一个数值,这样就得到了结果为每个词的置信度了,最后通过一个softmax就得到了概率,这样就搭建好了transformer。
在网上[2]看到Transformer好像还有增强版,分别为Universal Transformer和Transformer-XL,不过暂时没有特别多时间去看,没准有机会可以看。
特别地,在发博客的时候刚好看到了kdd-cup 2021的结果,Transformer在图的表示学习上,居然吊打了一众图神经网络,该作法为Graphormer,我感觉我可以先去读读GCN, GAT之类的再仔细了解一下图神经网络的基础,然后有机会的话再去看看这些Graphormer这个神奇的模型。
在网上找到了一段理得很清楚的Transformer的源码,而且有很清楚的讲解,有空可以看看[3]。
[1] https://arxiv.org/abs/1706.03762
[2] https://zhuanlan.zhihu.com/p/85612521
[3] https://blog.csdn.net/qq_18310041/article/details/95787616