分类目录:《深入理解深度学习》总目录
《深入理解深度学习——注意力机制(Attention Mechanism):基础知识》介绍了框架下的注意力机制的主要成分: 查询(自主提示)和键(非自主提示)之间的交互形成了注意力汇聚,注意力汇聚有选择地聚合了值(感官输入)以生成最终的输出。 本文将介绍注意力汇聚的更多细节, 以便从宏观上了解注意力机制在实践中的运作方式。 具体来说,1964年提出的Nadaraya-Watson核回归模型是一个简单但完整的例子,可以用于演示具有注意力机制的机器学习。
简单起见,考虑下面这个回归问题: 给定的成对的数据集 { ( x 1 , y 1 ) , ( x 2 , y 2 ) , ⋯ , ( x n , y n ) } \{(x_1, y_1), (x_2, y_2), \cdots, (x_n, y_n)\} {(x1,y1),(x2,y2),⋯,(xn,yn)}, 如何学习 f f f来预测任意新输入 s s s的输出 y ^ = f ( x ) \hat{y}=f(x) y^=f(x)?根据下面的非线性函数生成一个人工数据集, 其中加入的噪声项为 ϵ \epsilon ϵ:
y i = 2 sin ( x i ) + x i 0.8 + ϵ y_i = 2\sin(x_i) + x_i^{0.8} + \epsilon yi=2sin(xi)+xi0.8+ϵ
先使用最简单的估计器来解决回归问题。 基于平均汇聚来计算所有训练样本输出值的平均值:
f ( x ) = 1 n ∑ i = 1 n y i f(x)=\frac{1}{n}\sum_{i=1}^ny_i f(x)=n1i=1∑nyi
但这个估计器的预测函数值和真实函数值相差很大。
显然,平均汇聚忽略了输入。 于是Nadaraya和Watson提出了一个更好的想法, 根据输入的位置对输出进行加权:
f ( x ) = ∑ i = 1 n K ( x − x i ) ∑ j = 1 n K ( x − x j ) y i f(x) = \sum_{i=1}^n\frac{K(x - x_i)}{ \sum_{j=1}^nK(x - x_j)}y_i f(x)=i=1∑n∑j=1nK(x−xj)K(x−xi)yi
其中 K K K是核(kernel)。 上式所描述的估计器被称为 Nadaraya-Watson核回归(Nadaraya-Watson Kernel Regression)。 这里不会深入讨论核函数的细节, 但受此启发, 我们可以将注意力机制框架的角度重写上式, 成为一个更加通用的注意力汇聚(Attention Pooling)公式:
f ( x ) = α ( x , x i ) f(x)=\alpha(x, x_i) f(x)=α(x,xi)
其中 x x x是查询, ( x i , y i ) (x_i, y_i) (xi,yi)是键值对。 比较上式和平均汇聚的公式, 注意力汇聚 y i y_i yi是的加权平均。 将查询 x x x和键 x i x_i xi之间的关系建模为注意力权重(Attention Weight) α ( x , x i ) \alpha(x, x_i) α(x,xi), 如上式所示, 这个权重将被分配给每一个对应值 y i y_i yi。 对于任何查询,模型在所有键值对注意力权重都是一个有效的概率分布: 它们是非负的,并且总和为1。
为了更好地理解注意力汇聚, 下面考虑一个高斯核(Gaussian Kernel),其定义为:
K ( u ) = 1 2 π exp ( − u 2 2 ) K(u)=\frac{1}{\sqrt{2\pi}}\exp(-\frac{u^2}{2}) K(u)=2π1exp(−2u2)
将高斯核代入上式可以得到:
f ( x ) = ∑ i = 1 n α ( x , x i ) y i = ∑ i = 1 n exp ( − 1 2 ( x − x i ) ) 2 ∑ j = 1 n exp ( − 1 2 ( x − x j ) ) 2 y i = ∑ i = 1 n Softmax ( − 1 2 ( x − x i ) 2 ) y i \begin{aligned} f(x) &= \sum_{i=1}^n\alpha(x, x_i)y_i\\ &= \sum_{i=1}^n\frac{\exp(-\frac{1}{2}(x - x_i))^2}{ \sum_{j=1}^n\exp(-\frac{1}{2}(x - x_j))^2}y_i\\ &= \sum_{i=1}^n\text{Softmax}(-\frac{1}{2}(x - x_i)^2)y_i \end{aligned} f(x)=i=1∑nα(x,xi)yi=i=1∑n∑j=1nexp(−21(x−xj))2exp(−21(x−xi))2yi=i=1∑nSoftmax(−21(x−xi)2)yi
在上式中, 如果一个键越是接近给定的查询, 那么分配给这个键对应值的注意力权重就会越大, 也就“获得了更多的注意力”。值得注意的是,Nadaraya-Watson核回归是一个非参数模型。 因此, 上式是非参数的注意力汇聚(Nonparametric Attention Pooling)模型。 现在来观察注意力的权重。 这里测试数据的输入相当于查询,而训练数据的输入相当于键。 因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近, 注意力汇聚的注意力权重就越高。
非参数的Nadaraya-Watson核回归具有一致性(Consistency)的优点: 如果有足够的数据,此模型会收敛到最优结果。 尽管如此,我们还是可以轻松地将可学习的参数集成到注意力汇聚中。例如,与上式略有不同, 在下面的查询 x x x和键# x i x_i xi之间的距离乘以可学习参数:
f ( x ) = ∑ i = 1 n α ( x , x i ) y i = ∑ i = 1 n exp ( ( − 1 2 ( x − x i ) ) w ) 2 ∑ j = 1 n exp ( − 1 2 ( x − x i ) ) w ) 2 y i = ∑ i = 1 n Softmax ( − 1 2 ( ( x − x i ) w ) 2 ) y i \begin{aligned} f(x) &= \sum_{i=1}^n\alpha(x, x_i)y_i\\ &= \sum_{i=1}^n\frac{\exp((-\frac{1}{2}(x - x_i))w)^2}{ \sum_{j=1}^n\exp(-\frac{1}{2}(x - x_i))w)^2}y_i\\ &= \sum_{i=1}^n\text{Softmax}(-\frac{1}{2}((x - x_i)w)^2)y_i \end{aligned} f(x)=i=1∑nα(x,xi)yi=i=1∑n∑j=1nexp(−21(x−xi))w)2exp((−21(x−xi))w)2yi=i=1∑nSoftmax(−21((x−xi)w)2)yi
参考文献:
[1] Lecun Y, Bengio Y, Hinton G. Deep learning[J]. Nature, 2015
[2] Aston Zhang, Zack C. Lipton, Mu Li, Alex J. Smola. Dive Into Deep Learning[J]. arXiv preprint arXiv:2106.11342, 2021.