注意力机制是注意力计算规则能够应用的深度学习网络的载体, 同时包括一些必要的全连接层以及相关张量处理, 使其与应用网络融为一体. 使用自注意力计算规则的注意力机制称为自注意力机制.
NLP领域中, 当前的注意力机制大多数应用于seq2seq架构, 即编码器和解码器模型.
注意力机制实现步骤
第一步: 根据注意力计算规则, 对Q,K,V进行相应的计算.
第二步: 根据第一步采用的计算方法, 如果是拼接方法,则需要将Q与第二步的计算结果再进行拼接, 如果是转置点积, 一般是自注意力, Q与V相同, 则不需要进行与Q的拼接.
第三步: 最后为了使整个attention机制按照指定尺寸输出, 使用线性层作用在第二步的结果上做一个线性变换, 得到最终对Q的注意力表示.
实现代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attn(nn.Module):
def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
"""初始化函数中的参数有5个, query_size代表query的最后一维大小
key_size代表key的最后一维大小, value_size1代表value的导数第二维大小,
value = (1, value_size1, value_size2)
value_size2代表value的倒数第一维大小, output_size输出的最后一维大小"""
super(Attn, self).__init__()
# 将以下参数传入类中
self.query_size = query_size
self.key_size = key_size
self.value_size1 = value_size1
self.value_size2 = value_size2
self.output_size = output_size
# 初始化注意力机制实现第一步中需要的线性层.
self.attn = nn.Linear(self.query_size + self.key_size, value_size1)
# 初始化注意力机制实现第三步中需要的线性层.
self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)
def forward(self, Q, K, V):
"""forward函数的输入参数有三个, 分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的
张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量"""
# 第一步, 按照计算规则进行计算,
# 我们采用常见的第一种计算规则
# 将Q,K进行纵轴拼接, 做一次线性变化, 最后使用softmax处理获得结果
attn_weights = F.softmax(
self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)
# 然后进行第一步的后半部分, 将得到的权重矩阵与V做矩阵乘法计算,
# 当二者都是三维张量且第一维代表为batch条数时, 则做bmm运算
attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)
# 之后进行第二步, 通过取[0]是用来降维, 根据第一步采用的计算方法,
# 需要将Q与第一步的计算结果再进行拼接
output = torch.cat((Q[0], attn_applied[0]), 1)
# 最后是第三步, 使用线性层作用在第三步的结果上做一个线性变换并扩展维度,得到输出
# 因为要保证输出也是三维张量, 因此使用unsqueeze(0)扩展维度
output = self.attn_combine(output).unsqueeze(0)
return output, attn_weights
query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64
attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
Q = torch.randn(1,1,32)
K = torch.randn(1,1,32)
V = torch.randn(1,32,64)
out = attn(Q, K ,V)
print(out[0])
print(out[1])
结果如下:
tensor([[[-0.2961, -0.0948, 0.4384, -0.4684, 0.0987, 0.3926, 0.2671,
1.0258, 0.2068, -0.8418, 0.1220, -0.3244, -0.8128, 0.2292,
0.6818, -0.3369, -0.2666, 0.0036, 0.0643, -0.6318, -0.0867,
0.6521, -0.3778, -0.2478, -0.1729, 0.9106, 0.2469, 0.1512,
0.0736, 0.2501, 0.9162, -0.5796, 0.1865, 0.0234, -0.0553,
0.2651, -0.5230, -0.3136, 0.2308, 0.5429, -0.3149, -0.1805,
0.1518, -0.0573, -0.2517, -0.1196, 0.0647, 0.6827, -0.1228,
-0.2044, 0.0298, 0.2147, -0.3879, -0.0771, -0.1359, -0.1912,
-0.4390, 0.4078, 0.0616, 0.1442, 0.1604, -0.3253, -0.1718,
0.2007]]], grad_fn=)
tensor([[0.0275, 0.0272, 0.0259, 0.0258, 0.0529, 0.0228, 0.0382, 0.0111, 0.0544,
0.0352, 0.0188, 0.0241, 0.0375, 0.0172, 0.0194, 0.0528, 0.0124, 0.0263,
0.0811, 0.0194, 0.0238, 0.0553, 0.0232, 0.0468, 0.0183, 0.0193, 0.0075,
0.0193, 0.0382, 0.0188, 0.0362, 0.0636]])