注意力机制 - 注意力汇聚:Nadaraya-Watson核回归

文章目录

  • 注意力汇聚:Nadaraya-Watson核回归
    • 1 - 生成数据集
    • 2 - 平均汇聚
    • 3 - 非参数注意力汇聚
    • 4 - 带参数注意力汇聚
      • 批量矩阵乘法
      • 定义模型
      • 训练
    • 5 - 小结

注意力汇聚:Nadaraya-Watson核回归

框架下的注意力机制的主要成分:查询(自主提示)和键(非自主提示)之间交互形成了注意力汇聚,注意力汇聚有选择地聚合了值(感官输入)以生成最终的输出。在本节中,我们将介绍注意力汇聚的更多细节,以便从宏观上了解注意力机制在实践中的运作方式。1964年提出的Nadaraya-Watson核回归模型是⼀个简单但完整的例⼦,可以⽤于演⽰具有注意⼒机制的机器学习

import torch
from torch import nn
from d2l import torch as d2l

1 - 生成数据集

注意力机制 - 注意力汇聚:Nadaraya-Watson核回归_第1张图片

n_train = 50 # 训练样本数
x_train,_ = torch.sort(torch.rand(n_train) * 5) # 排序后的训练样本
def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0,0.5,(n_train,)) # 训练样本的输出
x_test = torch.arange(0,5,0.1) # 测试样本
y_truth = f(x_test) # 测试样本的真实输出
n_test = len(x_test) # 测试样本数
n_test
50

下面的函数将绘制所有的训练样本(样本由圆圈表示),不带噪声项的真实数据生成函数f(标记为“Truth”),以及学习得到的预测函数(标记为“Pred”)

def plot_kernel_reg(y_hat):
    d2l.plot(x_test,[y_truth,y_hat],'x','y',legend=['Truth','Pred'],xlim=[0,5],ylim=[-1,5])
    d2l.plt.plot(x_train,y_train,'o',alpha=0.5);

2 - 平均汇聚

注意力机制 - 注意力汇聚:Nadaraya-Watson核回归_第2张图片

y_hat = torch.repeat_interleave(y_train.mean(),n_test)
plot_kernel_reg(y_hat)


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-j6TRUXSQ-1662988499736)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107812.svg)]

3 - 非参数注意力汇聚

注意力机制 - 注意力汇聚:Nadaraya-Watson核回归_第3张图片
注意力机制 - 注意力汇聚:Nadaraya-Watson核回归_第4张图片

# X_repeat的形状:(n_test,n_train)
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1,n_train))

# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每⼀⾏都包含着要在给定的每个查询的值(y_train)之间分配的注意⼒权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2,dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights,y_train)
plot_kernel_reg(y_hat)


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bHHxRgOe-1662988499737)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107813.svg)]

现在,我们来观察注意力的权重,这里测试数据的输入相当于查询,而训练数据的输入相当于键。因为两个输入都是经过排序的,因此由观察可知,“查询-键”对越接近,注意力汇聚的注意力权重就越高

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                 xlabel='Sorted training inputs',
                 ylabel='Sorted testing inputs')


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-d4R52D1O-1662988499737)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107814.svg)]

4 - 带参数注意力汇聚

注意力机制 - 注意力汇聚:Nadaraya-Watson核回归_第5张图片

批量矩阵乘法

注意力机制 - 注意力汇聚:Nadaraya-Watson核回归_第6张图片

X = torch.ones((2,1,4))
Y = torch.ones((2,4,6))

torch.bmm(X,Y).shape
torch.Size([2, 1, 6])

在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值

weights = torch.ones((2,10)) * 0.1
values = torch.arange(20.0).reshape((2,10))
torch.bmm(weights.unsqueeze(1),values.unsqueeze(-1))
tensor([[[ 4.5000]],

        [[14.5000]]])

定义模型

基于带参数的注意力汇聚,使用小批量矩阵乘法,定义Nadaraya-Watson核回归的带参数版本为:

class NWKernelRegression(nn.Module):
    def __init__(self,**kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,),requires_grad = True))
        
    def forward(self,queries,keys,values):
        # queries和attention_weights的形状为(查询个数,“键-值”对个数)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1,keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 /2 ,dim=1)
        # values的形状为(查询个数,“键-值”对个数)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                        values.unsqueeze(-1)).reshape(-1)

训练

接下来,将训练数据集变换为键和值用于训练注意力模型。在带参数的注意力汇聚模型中,任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算,从而得到其对应的预测输出

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train,1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train,1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))

训练带参数的注意力汇聚模型时,使用平方损失函数和随机梯度下降

net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(),lr=0.5)
animator = d2l.Animator(xlabel='epoch',ylabel='loss',xlim=[1,5])

for epoch in range(5):
    trainer.zero_grad()
    l = loss(net(x_train,keys,values),y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5nsQsico-1662988499738)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107815.svg)]

如下所示,训练完带参数的注意力汇聚模型后,我们发现:在尝试拟合带噪声的训练数据时,预测结果绘制的线不如之前非参数模型的平滑

# keys的形状:(n_test,n_train),每⼀⾏包含着相同的训练输⼊(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-V9IQPxat-1662988499738)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107816.svg)]

为什么新的模型更不平滑了呢?我们看一下输出结果的绘制图:与非参数的注意力汇聚模型相比,带参数的模型加入可学习的参数后,曲线在注意力权重较大的区域变得更不平滑

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                xlabel='Sorted training inputs',
                ylabel='Sorted testing inputs')


[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rRT7kL7O-1662988499738)(https://yingziimage.oss-cn-beijing.aliyuncs.com/img/202209122107817.svg)]

5 - 小结

  • Nadaraya-Watson核回归时具有注意力机制的机器学习范例
  • Nadaraya-Watson核回归的注意⼒汇聚是对训练数据中输出的加权平均。从注意力的角度来看,分配给每个值的注意力权重取决于你将值所对应的键核查询作为输入的函数
  • 注意力汇聚可以分为非参数型核带参数型

你可能感兴趣的:(深度学习,深度学习,注意力机制)