mask是Transformer中很重要的一个概念,mask操作的目的有两个:
上面的第一个目的分别对应的是普通的Scaled Dot-Product Attention中的mask操作,而后一个目的对应的是Masked Multi-Head Attention中的Masked
至于这两个mask分别是如何进行的,我们下面来一一讲解。但是首先我们要弄明白multi-attention中的矩阵维度变化,transformer是如何训练和测试的
假设Multi-Head Attention层的输入为 Q ∈ R ( N , T q , d m o d e l ) Q \in \mathbb R^{(N,T_q, d_{model})} Q∈R(N,Tq,dmodel), K ∈ R ( N , T k , d m o d e l ) K \in \mathbb R^{(N,T_k,d_{model})} K∈R(N,Tk,dmodel), V ∈ R ( N , T k , d m o d e l ) V \in \mathbb R^{(N,T_k,d_{model})} V∈R(N,Tk,dmodel)。其中 N N N是batch_size, T q T_q Tq时 Q Q Q的maxlen, T k T_k Tk是 K K K的maxlen, d m o d e l d_{model} dmodel是最初单个词embedding的向量长度
接下来进行的是 h h h次线性变换, h h h实际就是多头注意力的头数,假设第 i i i次线性变换后会得到 Q i ∗ , K i ∗ , V i ∗ Q^*_i,K^*_i,V^*_i Qi∗,Ki∗,Vi∗,变换方式如下:
Q i ∗ = Q W i Q W i Q ∈ R ( d m o d e l , d k ) K i ∗ = Q W i K W i K ∈ R ( d m o d e l , d k ) V i ∗ = V W i V W i V ∈ R ( d m o d e l , d v ) Q^*_i = QW^Q_i \qquad W^Q_i \in \mathbb R^{(d_{model}, d_k)}\\ \\ K^*_i = QW^K_i \qquad W^K_i \in \mathbb R^{(d_{model}, d_k)}\\ \\ V^*_i = VW^V_i \qquad W^V_i \in \mathbb R^{(d_{model}, d_v)}\\ \\ Qi∗=QWiQWiQ∈R(dmodel,dk)Ki∗=QWiKWiK∈R(dmodel,dk)Vi∗=VWiVWiV∈R(dmodel,dv)
因此, Q i ∗ ∈ R ( N , T q , d k ) Q^*_i \in \mathbb R^{(N,T_q,d_k)} Qi∗∈R(N,Tq,dk), K i ∗ ∈ R ( N , T k , d k ) K^*_i \in \mathbb R^{(N,T_k,d_k)} Ki∗∈R(N,Tk,dk), V i ∗ ∈ R ( N , T k , d v ) V^*_i \in \mathbb R^{(N,T_k,d_v)} Vi∗∈R(N,Tk,dv)
接下来将 Q i ∗ , K i ∗ , V i ∗ Q^*_i,K^*_i,V^*_i Qi∗,Ki∗,Vi∗进行attention的操作我们分别来看:
t e m p 1 = Q i ∗ ( K i ∗ ) T d k t e m p 1 ∈ R ( N , T q , T k ) h e a d i = a t t e n t i o n ( Q i ∗ , K i ∗ , V i ∗ ) = s o f t m a x ( t e m p ) V temp_1 = \frac {Q^*_i(K^*_i)^T}{\sqrt d_k} \qquad temp_1 \in \mathbb R^{(N,T_q,T_k)}\\ head_i = attention(Q^*_i, K^*_i,V^*_i) = softmax(temp)V temp1=dkQi∗(Ki∗)Ttemp1∈R(N,Tq,Tk)headi=attention(Qi∗,Ki∗,Vi∗)=softmax(temp)V
因此我们可以看出来 h e a d i ∈ R ( N , T q , d v ) head_i \in \mathbb R^{(N, T_q, d_v)} headi∈R(N,Tq,dv)
然后我们进行的操作是,将 h h h次attention得到的操作连接起来,并乘 W O W^O WO矩阵
t e m p 2 = c o n c a t ( ( h e a d 1 , h e a d 2 , . . . . . . , h e a d h ) , − 1 ) t e m p 2 ∈ R ( N , T q , h d v ) M u l t i − H e a d ( Q , K , V ) = t e m p 2 W O W O ∈ R ( N , h d v , d m o d e l ) temp_2 = concat((head_1,head_2,......,head_h), -1) \qquad temp_2 \in \mathbb R^{(N, T_q, hd_v)} \\ Multi-Head(Q, K, V) = temp_2W^O \qquad W^O \in \mathbb R^{(N, hd_v, d_{model})} temp2=concat((head1,head2,......,headh),−1)temp2∈R(N,Tq,hdv)Multi−Head(Q,K,V)=temp2WOWO∈R(N,hdv,dmodel)
因此最终得到的结果的维度是 R ( N , T q , d m o d e l ) \mathbb R^{(N, T_q, d_{model})} R(N,Tq,dmodel)。你应该会发现,和原本输入的 Q Q Q是同样的维度
我们首先做一个假设,训练的batch大小为1,暂时忽略掉Positional Encoding
一个batch包含代翻译的句子集 X X X和翻译后的句子集(ground_truth) Y Y Y,其中:
X ∈ R ( 1 , T 1 , d m o d e l ) Y ∈ R ( 1 , T 2 , d m o d e l ) X \in \mathbb R^{(1, T_1,d_{model})} \\ Y \in \mathbb R^{(1, T_2,d_{model})} \\ X∈R(1,T1,dmodel)Y∈R(1,T2,dmodel)
T 1 T_1 T1是代翻译句子的最大长度, T 2 T_2 T2是ground_truth的最大长度
联系上图,训练过程的encoder下方的Inputs是 X X X,decoder下方的Outputs是 Y Y Y
最终softmax输出的矩阵维度是 ( 1 , T 2 , v o c a b _ s i z e ) (1,T_2,vocab\_size) (1,T2,vocab_size), T 2 T_2 T2表示的是groud_truth的maxlen,因此矩阵第二维的每一个向量都代表着这个句子中一个词的概率分布,比如 ( 1 , 2 , 3 ) (1,2,3) (1,2,3)代表的是第一个句子的第二个词是词表第三个词的概率
由上也不难理解为啥transformer的训练过程可以并行
因此transformer其实是对groud_truth的每一个词进一次预测,所以其损失函数就可以是交叉熵函数
对于测试样例或者没有ground_truth的样本,Transformer生成句子的模式就和上图差不多(真实的Transformer是由6个encoder和6个decoder组成的)
每一个词的生成都需要输入之前生成的内容,这也是为什么测试和评估过程不是并行的原因
最后,通过上面的内容,解释一下为什么要mask以及是如何操作的
在transformer中mask操作其实由三种,我们可以简单的分为:
我们依次来进行介绍,在此之前我们进行一个假设 T 1 = 4 , T 2 = 4 , d k = 4 , d v = 4 T_1=4,T_2=4,d_k=4,d_v=4 T1=4,T2=4,dk=4,dv=4
我们来看一个维度和单个句子一样的矩阵
[ 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 ] \begin{bmatrix} 1&1&1&1 \\ 1&1&1&1 \\ 0&0&0&0 \\ 0&0&0&0 \end{bmatrix} ⎣⎢⎢⎡1100110011001100⎦⎥⎥⎤
以上就是一个mask矩阵,这个mask矩阵表示的是句子后两个词是padding的内容,所以全都是0。我们在进行attention时不应该将这里考虑进attention,当给mask矩阵为0的对应位置替换一个负很大的值后,相应attention的结果就会趋近为0。
这部分参考了代码transformer的实现,接下来的mask的解释主要是基于该代码
首先我们应该知道,对key进行mask的操作实际上是在attention的softmax之前的,参考代码中的mask矩阵是如下形式的
[ 1 1 0 0 1 1 0 0 1 1 0 0 1 1 0 0 ] \begin{bmatrix} 1&1&0&0\\ 1&1&0&0\\ 1&1&0&0\\ 1&1&0&0\\ \end{bmatrix} ⎣⎢⎢⎡1111111100000000⎦⎥⎥⎤
因为这个mask矩阵起作用的时间点是 Q K T QK^T QKT计算之后,对 Q K T QK^T QKT进行的mask操作,此时由于 K K K经过了转置,所以相应的0的位置也变到了列上面来,于是将 Q K T QK^T QKT的响应位置替换为负的很大值,那么 s o f t m a x ( Q K T ) softmax(QK^T) softmax(QKT)对应位置的attention权重就会接近0
padding_num = -2 ** 32 + 1
masks = tf.sign(tf.reduce_sum(tf.abs(keys), axis=-1)) # (N, T_k),经过这个计算后,若是padding的部分就会变为0
masks = tf.expand_dims(masks, 1) # (N, 1, T_k) 在1的位置上增加一个维度
masks = tf.tile(masks, [1, tf.shape(queries)[1], 1]) # (N, T_q, T_k) tile用于扩展倍数,本句表示二维扩展T_q倍
# ones_like表示的是创建一个同样维度全都是1的矩阵
paddings = tf.ones_like(inputs) * padding_num
# tf.where(input, a, b)表示a中input对应位置维0的不变,其余的替换为b位置的数值
outputs = tf.where(tf.equal(masks, 0), paddings, inputs) # (N, T_q, T_k)
对Q进行的mask操作其实是最简单的了,因为 Q Q Q也存在padding的位置,在进行一个Multi-Head Attention计算后,就使得原来是0的位置不是0,所以attention输出的这些位置也应该为空,所以只需要在attention计算之后把相应的位置替换为0即可
masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1)) # (N, T_q)
masks = tf.expand_dims(masks, -1) # (N, T_q, 1)
masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]]) # (N, T_q, T_k)
outputs = inputs*masks
这个稍微复杂一点,我们先看看对应的mask矩阵
[ 1 0 0 0 1 1 0 0 1 1 1 0 1 1 1 1 ] \begin{bmatrix} 1&0&0&0\\ 1&1&0&0\\ 1&1&1&0\\ 1&1&1&1\\ \end{bmatrix} ⎣⎢⎢⎡1111011100110001⎦⎥⎥⎤
这个是在 Q K T QK^T QKT之后,softmax之前的进行的,我们来慢慢分析
Q = K = V = [ s 1 s 2 s 3 s 4 ] Q = K = V = \begin{bmatrix} s_1\\s_2\\s_3\\s_4 \end{bmatrix} Q=K=V=⎣⎢⎢⎡s1s2s3s4⎦⎥⎥⎤
接下来进行计算 Q K T QK^T QKT
Q K T = [ s 1 s 1 T s 1 s 2 T s 1 s 3 T s 1 s 4 T s 2 s 1 T s 2 s 2 T s 2 s 3 T s 2 s 4 T s 3 s 1 T s 3 s 2 T s 3 s 3 T s 3 s 4 T s 4 s 1 T s 4 s 2 T s 4 s 3 T s 4 s 4 T ] QK^T=\begin{bmatrix} s_1s_1^T&s_1s_2^T&s_1s_3^T&s_1s_4^T \\ s_2s_1^T&s_2s_2^T&s_2s_3^T&s_2s_4^T \\ s_3s_1^T&s_3s_2^T&s_3s_3^T&s_3s_4^T \\ s_4s_1^T&s_4s_2^T&s_4s_3^T&s_4s_4^T \\ \end{bmatrix} QKT=⎣⎢⎢⎡s1s1Ts2s1Ts3s1Ts4s1Ts1s2Ts2s2Ts3s2Ts4s2Ts1s3Ts2s3Ts3s3Ts4s3Ts1s4Ts2s4Ts3s4Ts4s4T⎦⎥⎥⎤
mask矩阵作用于 Q K T QK^T QKT上,可以得到(上面程序中用 2 32 + 1 2^{32} + 1 232+1代替 ∞ \infty ∞)
[ s 1 s 1 T − ∞ − ∞ − ∞ s 2 s 1 T s 2 s 2 T − ∞ − ∞ s 3 s 1 T s 3 s 2 T s 3 s 3 T − ∞ s 4 s 1 T s 4 s 2 T s 4 s 3 T s 4 s 4 T ] \begin{bmatrix} s_1s_1^T&-\infty &-\infty&-\infty \\ s_2s_1^T&s_2s_2^T&-\infty&-\infty \\ s_3s_1^T&s_3s_2^T&s_3s_3^T&-\infty \\ s_4s_1^T&s_4s_2^T&s_4s_3^T&s_4s_4^T \\ \end{bmatrix} ⎣⎢⎢⎡s1s1Ts2s1Ts3s1Ts4s1T−∞s2s2Ts3s2Ts4s2T−∞−∞s3s3Ts4s3T−∞−∞−∞s4s4T⎦⎥⎥⎤
这里省略掉除以 d k \sqrt d_k dk,进行softmax可以得到
[ a 11 0 0 0 a 21 a 22 0 0 a 31 a 32 a 33 0 a 41 a 42 a 43 a 44 ] \begin{bmatrix} a_{11}&0&0&0\\ a_{21}&a_{22}&0&0\\ a_{31}&a_{32}&a_{33}&0\\ a_{41}&a_{42}&a_{43}&a_{44} \end{bmatrix} ⎣⎢⎢⎡a11a21a31a410a22a32a4200a33a43000a44⎦⎥⎥⎤
然后和V相乘
[ a 11 s 1 0 0 0 a 21 s 1 a 22 s 2 0 0 a 31 s 1 a 32 s 2 a 33 s 3 0 a 41 s 1 a 42 s 2 a 43 s 3 a 44 s 4 ] \begin{bmatrix} a_{11}s_1&0&0&0\\ a_{21}s_1&a_{22}s_2&0&0\\ a_{31}s_1&a_{32}s_2&a_{33}s_3&0\\ a_{41}s_1&a_{42}s_2&a_{43}s_3&a_{44}s_4 \end{bmatrix} ⎣⎢⎢⎡a11s1a21s1a31s1a41s10a22s2a32s2a42s200a33s3a43s3000a44s4⎦⎥⎥⎤
那么使用 [ a 21 s 1 , a 22 s 2 , 0 , 0 ] [a_{21}s_1 ,a_{22}s_2,0,0] [a21s1,a22s2,0,0]来预测第二个词的概率分布的时候,已经没有了第三个词,第四个词的信息,这就是让训练在生成当前词的时候不会注意到之后的词的原因
diag_vals = tf.ones_like(inputs[0, :, :]) # (T_q, T_k)
# 这一句的意思是生成一个上三角矩阵,上三角矩阵用来对decoder的结果进行mask
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k)
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1]) # (N, T_q, T_k)
paddings = tf.ones_like(masks) * padding_num
# 将mask为0处全部变为负无穷
outputs = tf.where(tf.equal(masks, 0), paddings, inputs)
[参考链接]
参考文档:attention is all you need
参考程序:Kyubyong/transformer
如有错误,敬请指正,欢迎交流