【NLP】Understanding Self-Attention (PyTorch Implementation)

Reference: 【手撕Self-Attention】self-Attention的numpy实现和pytorch实现 (顾道长生’的博客, CSDN)
Reference: Self-Attention 原理与代码实现 (DonngZH的博客, CSDN)

Slightly modified from the references above.
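The code below implements scaled dot-product self-attention, where the query, key, and value matrices are all linear projections of the same input:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

The $1/\sqrt{d_k}$ scaling keeps the dot products from growing with the key dimension, which would otherwise push the softmax into regions with vanishing gradients.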

from math import sqrt

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    # input : batch_size * seq_len * input_dim
    # q : batch_size * seq_len * dim_k
    # k : batch_size * seq_len * dim_k
    # v : batch_size * seq_len * dim_v
    def __init__(self, input_dim, dim_k, dim_v):
        super(SelfAttention, self).__init__()
        self.dim_q = dim_k  # by convention Q and K share the same dimension
        self.dim_k = dim_k
        self.dim_v = dim_v

        # linear projections that produce Q, K, V from the input
        self.linear_q = nn.Linear(input_dim, dim_k, bias=False)
        self.linear_k = nn.Linear(input_dim, dim_k, bias=False)
        self.linear_v = nn.Linear(input_dim, dim_v, bias=False)
        self._norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
        # x: batch_size, seq_len, input_dim

        q = self.linear_q(x)  # batch_size, seq_len, dim_k
        k = self.linear_k(x)  # batch_size, seq_len, dim_k
        v = self.linear_v(x)  # batch_size, seq_len, dim_v
        # Q @ K^T, scaled by 1 / sqrt(dim_k)
        dist = torch.bmm(q, k.transpose(1, 2)) * self._norm_fact  # batch_size, seq_len, seq_len
        # softmax over the last dim (each row, i.e. per token) to normalize into attention weights
        dist = torch.softmax(dist, dim=-1)  # batch_size, seq_len, seq_len
        # weight V by the attention coefficients to get the final output
        att = torch.bmm(dist, v)  # batch_size, seq_len, dim_v
        return att
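The first reference also walks through a numpy version of the same computation. For comparison, here is a minimal numpy sketch (my addition, not from the original post): Wq, Wk, Wv are plain arrays standing in for the nn.Linear layers above (note that nn.Linear stores its weight transposed, so Wq corresponds to linear_q.weight.T).

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # subtract the max for numerical stability
    return e / e.sum(axis=axis, keepdims=True)

def self_attention_np(X, Wq, Wk, Wv):
    # X: (batch, seq_len, input_dim); Wq, Wk: (input_dim, dim_k); Wv: (input_dim, dim_v)
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    dist = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(Wk.shape[1]), axis=-1)
    return dist @ V  # (batch, seq_len, dim_v)

A quick demo of the PyTorch module: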
torch.manual_seed(0)
batch_size = 3
seq_len = 2
input_dim = 4  
# seq_len: number of tokens in the sequence; input_dim: dimensionality of each token's embedding
X = torch.randn(batch_size, seq_len, input_dim)
X.shape
X
torch.Size([3, 2, 4])
tensor([[[-1.1258, -1.1524, -0.2506, -0.4339],
         [ 0.8487,  0.6920, -0.3160, -2.1152]],

        [[ 0.4681, -0.1577,  1.4437,  0.2660],
         [ 0.1665,  0.8744, -0.1435, -0.1116]],

        [[ 0.9318,  1.2590,  2.0050,  0.0537],
         [ 0.6181, -0.4128, -0.8411, -2.3160]]])
dim_k = 3  # chosen arbitrarily for this demo
dim_v = 3

self_attention = SelfAttention(input_dim, dim_k, dim_v)
res = self_attention(X)
res
tensor([[[-0.1637,  0.2921, -0.7867],
         [-0.1733,  0.2150, -0.8028]],

        [[ 0.0750,  0.0588,  0.5038],
         [ 0.0887, -0.0084,  0.5341]],

        [[ 0.0629,  0.4331,  0.0502],
         [ 0.1734,  0.2077,  1.1355]]], grad_fn=<BmmBackward0>)
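The output has shape batch_size × seq_len × dim_v, as expected. As a sanity check (my addition, assuming PyTorch 2.0 or later), the same numbers can be reproduced with the built-in torch.nn.functional.scaled_dot_product_attention by feeding it the module's own Q, K, V projections; its default scaling is also 1 / sqrt(d_k):

import torch.nn.functional as F

# project X with the module's own weights, then call the built-in kernel
q = self_attention.linear_q(X)
k = self_attention.linear_k(X)
v = self_attention.linear_v(X)
ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(res, ref, atol=1e-6))  # expected: True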
