【Reference: 【手撕Self-Attention】self-Attention的numpy实现和pytorch实现_顾道长生’的博客-CSDN博客】
【Reference: Self-Attention 原理与代码实现_DonngZH的博客-CSDN博客】
Slightly modified from the above.
from math import sqrt

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    # input : batch_size * seq_len * input_dim
    # q     : batch_size * seq_len * dim_k
    # k     : batch_size * seq_len * dim_k
    # v     : batch_size * seq_len * dim_v
    def __init__(self, input_dim, dim_k, dim_v):
        super(SelfAttention, self).__init__()
        self.dim_q = dim_k  # by convention Q and K share the same dimension
        self.dim_k = dim_k
        self.dim_v = dim_v
        # linear projections for Q, K, V
        self.linear_q = nn.Linear(input_dim, dim_k, bias=False)
        self.linear_k = nn.Linear(input_dim, dim_k, bias=False)
        self.linear_v = nn.Linear(input_dim, dim_v, bias=False)
        self._norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
        # x: batch_size, seq_len, input_dim
        q = self.linear_q(x)  # batch_size, seq_len, dim_k
        k = self.linear_k(x)  # batch_size, seq_len, dim_k
        v = self.linear_v(x)  # batch_size, seq_len, dim_v

        # Q @ K^T, scaled by 1/sqrt(dim_k)
        dist = torch.bmm(q, k.transpose(1, 2)) * self._norm_fact  # batch_size, seq_len, seq_len
        # softmax over the last dim (i.e. each row / each token) to get the attention weights
        dist = torch.softmax(dist, dim=-1)  # batch_size, seq_len, seq_len
        # weight V by the attention coefficients to get the final output
        att = torch.bmm(dist, v)
        return att
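The title of the first reference also mentions a NumPy implementation. Below is a minimal NumPy sketch of the same computation for comparison; the weight matrices w_q, w_k, w_v and the softmax helper are illustrative names introduced here, not part of the PyTorch module above.

import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax along the given axis
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_attention_np(x, w_q, w_k, w_v):
    # x: (batch_size, seq_len, input_dim)
    # w_q, w_k: (input_dim, dim_k); w_v: (input_dim, dim_v)
    q = x @ w_q                                                  # (batch_size, seq_len, dim_k)
    k = x @ w_k                                                  # (batch_size, seq_len, dim_k)
    v = x @ w_v                                                  # (batch_size, seq_len, dim_v)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(w_k.shape[1])    # (batch_size, seq_len, seq_len)
    weights = softmax(scores, axis=-1)                           # each row sums to 1
    return weights @ v                                           # (batch_size, seq_len, dim_v)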
torch.manual_seed(0)

batch_size = 3
seq_len = 2    # number of tokens in the sequence
input_dim = 4  # dimensionality of each token's embedding

X = torch.randn(batch_size, seq_len, input_dim)
X.shape
X
torch.Size([3, 2, 4])
tensor([[[-1.1258, -1.1524, -0.2506, -0.4339],
         [ 0.8487,  0.6920, -0.3160, -2.1152]],

        [[ 0.4681, -0.1577,  1.4437,  0.2660],
         [ 0.1665,  0.8744, -0.1435, -0.1116]],

        [[ 0.9318,  1.2590,  2.0050,  0.0537],
         [ 0.6181, -0.4128, -0.8411, -2.3160]]])
dim_k = 3  # chosen arbitrarily
dim_v = 3
self_attention = SelfAttention(input_dim, dim_k, dim_v)
res = self_attention(X)
res
tensor([[[-0.1637,  0.2921, -0.7867],
         [-0.1733,  0.2150, -0.8028]],

        [[ 0.0750,  0.0588,  0.5038],
         [ 0.0887, -0.0084,  0.5341]],

        [[ 0.0629,  0.4331,  0.0502],
         [ 0.1734,  0.2077,  1.1355]]], grad_fn=<BmmBackward0>)
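As a sanity check (not in the original post), the result can be cross-checked against PyTorch's built-in scaled dot-product attention, which applies the same 1/sqrt(dim_k) scaling by default; torch.nn.functional.scaled_dot_product_attention requires PyTorch >= 2.0. This also confirms the output shape is (batch_size, seq_len, dim_v).

import torch.nn.functional as F

# reuse the module's own projections, then compare against the built-in kernel
q = self_attention.linear_q(X)
k = self_attention.linear_k(X)
v = self_attention.linear_v(X)
ref = F.scaled_dot_product_attention(q, k, v)

print(res.shape)                            # torch.Size([3, 2, 3])
print(torch.allclose(res, ref, atol=1e-6))  # expected: True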