为什么 dot-product attention 需要被 scaled?

前言

注意力机制也有很多种类,不同的注意力机制对应着不同的对齐分数(alignment score)计算方式。有关注意力机制的总结,大家可以看看这篇博客:Attention? Attention!

在 Attention Is All You Need 这篇论文中,有提到两种较为常见的注意力机制:additive attention 和 dot-product attention。并讨论到,当 query 和 key 向量维度 d k d_k dk 较小时,这两种注意力机制效果相当,但当 d k d_k dk 较大时,additive attention 要优于 dot-product attention. 但是 dot-product attention 在计算方面更具有优势。为了利用 dot-product attention 的优势且消除当 d k d_k dk 较大时 dot-product attention 的不足,原文采用 scaled dot-product attention。

正文

那造成这种情况(但当 d k d_k dk 较大时,additive attention 要优于 dot-product attention)的原因是什么?下面是原论文中的解释(当 d k d_k dk 较大时,向量内积的值也会容易变得很大,这时 softmax 函数的梯度会非常的小)。

We suspect that for large values of d k d_k dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely samll gradients.

我们知道,计算完各个 key 的对齐分数后需要将所有 key 的对齐分数输入到 s o f t m a x softmax softmax 激活函数中,得到规范化的注意力权重。

dot-product attention 中的对齐分数的计算公式为:
s c o r e ( q , k ) = q T k score(q, k) = q^T k score(q,k)=qTk

先解释:为什么当 d k d_k dk 较大时,向量内积容易取很大的值(借用原论文的注释)

假设 query 和 key 向量中的元素都是相互独立的均值为 0,方差为 1 的随机变量,那么这两个向量的内积 q T k = ∑ i = 1 d k q i k i q^T k = \sum_{i=1}^{d_k} q_ik_i qTk=i=1dkqiki 的均值为 0,而方差为 d k d_k dk.

证明:
已知 E [ q i ] = E [ k i ] = 0 ,  Var ( q i ) = Var ( k i ) = 1 \text{E}[q_i] = \text{E}[k_i] = 0,\ \text{Var}(q_i)=\text{Var}(k_i)=1 E[qi]=E[ki]=0, Var(qi)=Var(ki)=1.

由于 q i q_i qi k i k_i ki 相互独立,则两者的协方差为 0:
Cov ( q i , k i ) = E [ ( q i − E [ q i ] ) ( k i − E [ k i ] ) ] = E [ q i k i ] − E [ q i ] E [ k i ] = 0 \begin{aligned} \text{Cov}(q_i,k_i) &= \text{E}\left[\left(q_i-\text{E}[q_i]\right)\left(k_i-\text{E}[k_i]\right)\right] \\ &= \text{E}[q_ik_i] - \text{E}[q_i] \text{E}[k_i] \\ &= 0 \end{aligned} Cov(qi,ki)=E[(qiE[qi])(kiE[ki])]=E[qiki]E[qi]E[ki]=0
故得 E [ q i k i ] = E [ q i ] E [ k i ] = 0 \text{E}[q_ik_i] = \text{E}[q_i] \text{E}[k_i] = 0 E[qiki]=E[qi]E[ki]=0.

对于方差,有:
Var ( q i ) = E [ q i 2 ] − ( E [ q i ] ) 2 = E [ q i 2 ] = 1 Var ( k i ) = E [ k i 2 ] = 1 \begin{aligned} \text{Var}(q_i) &= \text{E}[q_i^2] - (\text{E}[q_i])^2\\ &= \text{E}[q_i^2] \\ &= 1 \\ \text{Var}(k_i) &= \text{E}[k_i^2] = 1 \end{aligned} Var(qi)Var(ki)=E[qi2](E[qi])2=E[qi2]=1=E[ki2]=1
故得:
Var ( q i k i ) = E [ ( q i k i ) 2 ] − ( E [ q i k i ] ) 2 = E [ q i 2 ] E [ k i 2 ] − ( E [ q i ] E [ k i ] ) 2 = Var ( q i ) Var ( k i ) = 1 \begin{aligned} \text{Var}(q_ik_i) &= \text{E}[(q_ik_i)^2] - (\text{E}[q_ik_i])^2 \\ &= \text{E}[q_i^2]\text{E}[k_i^2] - (\text{E}[q_i] \text{E}[k_i])^2 \\ & = \text{Var}(q_i)\text{Var}(k_i) \\ & = 1 \end{aligned} Var(qiki)=E[(qiki)2](E[qiki])2=E[qi2]E[ki2](E[qi]E[ki])2=Var(qi)Var(ki)=1
由于对于两个相互独立的随机变量有如下定义:
E [ X + Y ] = E [ X ] + E [ Y ] Var(X+Y) = Var(X) + Var(Y) + 2 Cov ( X , Y )     = Var(X) + Var(Y) \begin{aligned} &\text{E}[X+Y] = \text{E}[X] +\text{E}[Y]\\ &\text{Var(X+Y)} = \text{Var(X)} + \text{Var(Y)} + 2\text{Cov}(X,Y) \\ &\qquad \qquad \ \ \ =\text{Var(X)} + \text{Var(Y)} \end{aligned} E[X+Y]=E[X]+E[]Var(X+Y)=Var(X)+Var(Y)+2Cov(X,Y)   =Var(X)+Var(Y)
综上,可得:
E [ q T k ] = ∑ i = 1 d k E [ q i k i ] = 0 Var ( q T k ) = ∑ i = 1 d k Var ( q i k i ) = d k \begin{aligned} &\text{E}[q^T k ] = \sum_{i=1}^{d_k} \text{E}[q_ik_i] = 0\\ &\text{Var}(q^T k) = \sum_{i=1}^{d_k} \text{Var}(q_ik_i) = d_k \end{aligned} E[qTk]=i=1dkE[qiki]=0Var(qTk)=i=1dkVar(qiki)=dk
所以,可以看出,当 d k d_k dk 较大时, q T k q^Tk qTk 的方差较大,不同的 key 与同一个 query 算出的对齐分数可能会相差很大,有的远大于 0,有的则远小于 0.

再解释:向量内积的值(对齐分数)较大时,softmax 函数梯度很小

先介绍一下 softmax 函数:

s o f t m a x softmax softmax 函数是 logistic (或 sigmoid)函数在多类问题上的引申(有关于 sigmoid 函数的信息可查看我的另一篇博客),记为 S S S,其公式为:
S ( x i ) = e x i ∑ j = 0 n e x j S(x_i) = \frac{e^{x_i}}{\sum_{j=0}^n e^{x_j}} S(xi)=j=0nexjexi
S ( x i ) S(x_i) S(xi) 求偏导,可得:
∂ ∂ x i S ( x i ) = S ( x i ) ( 1 − S ( x i ) ) ∂ ∂ x j S ( x i ) = − S ( x i ) S ( x j ) \begin{aligned} \frac{\partial}{\partial x_i} S(x_i) &= S(x_i)(1-S(x_i)) \\ \frac{\partial}{\partial x_j} S(x_i) &= -S(x_i)S(x_j) \end{aligned} xiS(xi)xjS(xi)=S(xi)(1S(xi))=S(xi)S(xj)
从上面的结果可以看出:

  • x i x_i xi 相对于其他的 x j ( j ≠ i ) x_j(j \neq i) xj(j=i) 特别大时, S ( x i ) S(x_i) S(xi) 趋近于 1,则 ∂ ∂ x i S ( x i ) \frac{\partial}{\partial x_i} S(x_i) xiS(xi) ∂ ∂ x i S ( x j ) \frac{\partial}{\partial x_i} S(x_j) xiS(xj) 都趋近于 0.
  • x i x_i xi 相对较小时, S ( x i ) S(x_i) S(xi) 趋近于 0,则 ∂ ∂ x i S ( x i ) \frac{\partial}{\partial x_i} S(x_i) xiS(xi) ∂ ∂ x i S ( x j ) \frac{\partial}{\partial x_i} S(x_j) xiS(xj) 也都趋近于 0.

也就是, x i x_i xi 趋于 0 或 1 时,上述的两种偏导数都趋于零

现在,我们就可以把这里的 x i x_i xi 替换成前一部分讲到的 query 和 key 向量的内积 q T k q^T k qTk 了。

在前一部分我们有得出结论:当 d k d_k dk 较大时, q T k q^Tk qTk 的方差较大,不同的 key 与同一个 query 算出的对齐分数可能会相差很大,有的远大于 0,有的则远小于 0.

所以,当 d k d_k dk 较大时,很有可能存在某个 key,其与 query 计算出来的对齐分数远大于其他的 key 与该 query 算出的对齐分数。这时, s o f t m a x softmax softmax 函数对各个 q T k q^Tk qTk 的偏导数都趋于 0.

其结果就是, s o f t m a x softmax softmax 函数梯度过低(趋于零),使得模型误差反向传播(back-propagation)经过 s o f t m a x softmax softmax 函数后无法继续传播到模型前面部分的参数上,造成这些参数无法得到更新,最终影响模型的训练效率。

那么如何消除如上 dot-product attention 的问题呢?一种方法就是论文中的对 dot-product attention 进行缩放(除以 d k \sqrt{d_k} dk ),获得 scaled dot-product attention。其对齐分数的计算公式为:
s c o r e ( q , k ) = q T k d k score(q, k) = \frac{q^T k}{\sqrt{d_k}} score(q,k)=dk qTk
根据方差的计算法则: Var ( k x ) = k 2 Var ( x ) \text{Var}(kx) = k^2\text{Var}(x) Var(kx)=k2Var(x),可知缩放后, s c o r e ( q , k ) score(q,k) score(q,k) 的方差由原来的 d k d_k dk 缩小到了 1. 这就消除了 dot-product attention 在 d k d_k dk 较大时遇到的问题。这时,softmax 函数的梯度就不容易趋近于零了。

这就是为什么 dot-product attention 需要被 scaled.

总结

本博客基于随机变量的期望和方差以及 s o f t m a x softmax softmax 函数的性质,详细说明了——为什么 dot-product attention 需要被 scaled.

参考源

  • Attention Is All You Need
  • Attention? Attention!

推荐资源(Transformer 相关)

  • The Illustrated Transformer(概念上)
  • The Annotated Transformer(代码实现上)

你可能感兴趣的:(神经网络)