注意力机制也有很多种类,不同的注意力机制对应着不同的对齐分数(alignment score)计算方式。有关注意力机制的总结,大家可以看看这篇博客:Attention? Attention!
在 Attention Is All You Need 这篇论文中,有提到两种较为常见的注意力机制:additive attention 和 dot-product attention。并讨论到,当 query 和 key 向量维度 d k d_k dk 较小时,这两种注意力机制效果相当,但当 d k d_k dk 较大时,additive attention 要优于 dot-product attention. 但是 dot-product attention 在计算方面更具有优势。为了利用 dot-product attention 的优势且消除当 d k d_k dk 较大时 dot-product attention 的不足,原文采用 scaled dot-product attention。
那造成这种情况(但当 d k d_k dk 较大时,additive attention 要优于 dot-product attention)的原因是什么?下面是原论文中的解释(当 d k d_k dk 较大时,向量内积的值也会容易变得很大,这时 softmax 函数的梯度会非常的小)。
We suspect that for large values of d k d_k dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely samll gradients.
我们知道,计算完各个 key 的对齐分数后需要将所有 key 的对齐分数输入到 s o f t m a x softmax softmax 激活函数中,得到规范化的注意力权重。
dot-product attention 中的对齐分数的计算公式为:
s c o r e ( q , k ) = q T k score(q, k) = q^T k score(q,k)=qTk
假设 query 和 key 向量中的元素都是相互独立的均值为 0,方差为 1 的随机变量,那么这两个向量的内积 q T k = ∑ i = 1 d k q i k i q^T k = \sum_{i=1}^{d_k} q_ik_i qTk=∑i=1dkqiki 的均值为 0,而方差为 d k d_k dk.
证明:
已知 E [ q i ] = E [ k i ] = 0 , Var ( q i ) = Var ( k i ) = 1 \text{E}[q_i] = \text{E}[k_i] = 0,\ \text{Var}(q_i)=\text{Var}(k_i)=1 E[qi]=E[ki]=0, Var(qi)=Var(ki)=1.
由于 q i q_i qi 与 k i k_i ki 相互独立,则两者的协方差为 0:
Cov ( q i , k i ) = E [ ( q i − E [ q i ] ) ( k i − E [ k i ] ) ] = E [ q i k i ] − E [ q i ] E [ k i ] = 0 \begin{aligned} \text{Cov}(q_i,k_i) &= \text{E}\left[\left(q_i-\text{E}[q_i]\right)\left(k_i-\text{E}[k_i]\right)\right] \\ &= \text{E}[q_ik_i] - \text{E}[q_i] \text{E}[k_i] \\ &= 0 \end{aligned} Cov(qi,ki)=E[(qi−E[qi])(ki−E[ki])]=E[qiki]−E[qi]E[ki]=0
故得 E [ q i k i ] = E [ q i ] E [ k i ] = 0 \text{E}[q_ik_i] = \text{E}[q_i] \text{E}[k_i] = 0 E[qiki]=E[qi]E[ki]=0.
对于方差,有:
Var ( q i ) = E [ q i 2 ] − ( E [ q i ] ) 2 = E [ q i 2 ] = 1 Var ( k i ) = E [ k i 2 ] = 1 \begin{aligned} \text{Var}(q_i) &= \text{E}[q_i^2] - (\text{E}[q_i])^2\\ &= \text{E}[q_i^2] \\ &= 1 \\ \text{Var}(k_i) &= \text{E}[k_i^2] = 1 \end{aligned} Var(qi)Var(ki)=E[qi2]−(E[qi])2=E[qi2]=1=E[ki2]=1
故得:
Var ( q i k i ) = E [ ( q i k i ) 2 ] − ( E [ q i k i ] ) 2 = E [ q i 2 ] E [ k i 2 ] − ( E [ q i ] E [ k i ] ) 2 = Var ( q i ) Var ( k i ) = 1 \begin{aligned} \text{Var}(q_ik_i) &= \text{E}[(q_ik_i)^2] - (\text{E}[q_ik_i])^2 \\ &= \text{E}[q_i^2]\text{E}[k_i^2] - (\text{E}[q_i] \text{E}[k_i])^2 \\ & = \text{Var}(q_i)\text{Var}(k_i) \\ & = 1 \end{aligned} Var(qiki)=E[(qiki)2]−(E[qiki])2=E[qi2]E[ki2]−(E[qi]E[ki])2=Var(qi)Var(ki)=1
由于对于两个相互独立的随机变量有如下定义:
E [ X + Y ] = E [ X ] + E [ Y ] Var(X+Y) = Var(X) + Var(Y) + 2 Cov ( X , Y ) = Var(X) + Var(Y) \begin{aligned} &\text{E}[X+Y] = \text{E}[X] +\text{E}[Y]\\ &\text{Var(X+Y)} = \text{Var(X)} + \text{Var(Y)} + 2\text{Cov}(X,Y) \\ &\qquad \qquad \ \ \ =\text{Var(X)} + \text{Var(Y)} \end{aligned} E[X+Y]=E[X]+E[Y]Var(X+Y)=Var(X)+Var(Y)+2Cov(X,Y) =Var(X)+Var(Y)
综上,可得:
E [ q T k ] = ∑ i = 1 d k E [ q i k i ] = 0 Var ( q T k ) = ∑ i = 1 d k Var ( q i k i ) = d k \begin{aligned} &\text{E}[q^T k ] = \sum_{i=1}^{d_k} \text{E}[q_ik_i] = 0\\ &\text{Var}(q^T k) = \sum_{i=1}^{d_k} \text{Var}(q_ik_i) = d_k \end{aligned} E[qTk]=i=1∑dkE[qiki]=0Var(qTk)=i=1∑dkVar(qiki)=dk
所以,可以看出,当 d k d_k dk 较大时, q T k q^Tk qTk 的方差较大,不同的 key 与同一个 query 算出的对齐分数可能会相差很大,有的远大于 0,有的则远小于 0.
先介绍一下 softmax 函数:
s o f t m a x softmax softmax 函数是 logistic (或 sigmoid)函数在多类问题上的引申(有关于 sigmoid 函数的信息可查看我的另一篇博客),记为 S S S,其公式为:
S ( x i ) = e x i ∑ j = 0 n e x j S(x_i) = \frac{e^{x_i}}{\sum_{j=0}^n e^{x_j}} S(xi)=∑j=0nexjexi
对 S ( x i ) S(x_i) S(xi) 求偏导,可得:
∂ ∂ x i S ( x i ) = S ( x i ) ( 1 − S ( x i ) ) ∂ ∂ x j S ( x i ) = − S ( x i ) S ( x j ) \begin{aligned} \frac{\partial}{\partial x_i} S(x_i) &= S(x_i)(1-S(x_i)) \\ \frac{\partial}{\partial x_j} S(x_i) &= -S(x_i)S(x_j) \end{aligned} ∂xi∂S(xi)∂xj∂S(xi)=S(xi)(1−S(xi))=−S(xi)S(xj)
从上面的结果可以看出:
也就是,当 x i x_i xi 趋于 0 或 1 时,上述的两种偏导数都趋于零。
现在,我们就可以把这里的 x i x_i xi 替换成前一部分讲到的 query 和 key 向量的内积 q T k q^T k qTk 了。
在前一部分我们有得出结论:当 d k d_k dk 较大时, q T k q^Tk qTk 的方差较大,不同的 key 与同一个 query 算出的对齐分数可能会相差很大,有的远大于 0,有的则远小于 0.
所以,当 d k d_k dk 较大时,很有可能存在某个 key,其与 query 计算出来的对齐分数远大于其他的 key 与该 query 算出的对齐分数。这时, s o f t m a x softmax softmax 函数对各个 q T k q^Tk qTk 的偏导数都趋于 0.
其结果就是, s o f t m a x softmax softmax 函数梯度过低(趋于零),使得模型误差反向传播(back-propagation)经过 s o f t m a x softmax softmax 函数后无法继续传播到模型前面部分的参数上,造成这些参数无法得到更新,最终影响模型的训练效率。
那么如何消除如上 dot-product attention 的问题呢?一种方法就是论文中的对 dot-product attention 进行缩放(除以 d k \sqrt{d_k} dk),获得 scaled dot-product attention。其对齐分数的计算公式为:
s c o r e ( q , k ) = q T k d k score(q, k) = \frac{q^T k}{\sqrt{d_k}} score(q,k)=dkqTk
根据方差的计算法则: Var ( k x ) = k 2 Var ( x ) \text{Var}(kx) = k^2\text{Var}(x) Var(kx)=k2Var(x),可知缩放后, s c o r e ( q , k ) score(q,k) score(q,k) 的方差由原来的 d k d_k dk 缩小到了 1. 这就消除了 dot-product attention 在 d k d_k dk 较大时遇到的问题。这时,softmax 函数的梯度就不容易趋近于零了。
这就是为什么 dot-product attention 需要被 scaled.
本博客基于随机变量的期望和方差以及 s o f t m a x softmax softmax 函数的性质,详细说明了——为什么 dot-product attention 需要被 scaled.