Attention mechanism [code walkthrough matching the diagrams]

Table of Contents

      • Topic
      • Three-step attention mechanism + step-by-step code walkthrough
      • Running results

Topic

'''
Description: attention mechanism
Author: 365JHWZGo
Date: 2021-12-14 17:06:11
LastEditors: 365JHWZGo
LastEditTime: 2021-12-14 22:23:54
'''

Three-step attention mechanism + step-by-step code walkthrough

Import the libraries

import torch 
import torch.nn as nn
import torch.nn.functional as F

Attn

class Attn(nn.Module):
    def __init__(self, query_size, key_size, value_size1, value_size2):
        super(Attn, self).__init__()
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1   # number of value vectors (sequence length of V)
        self.value_size2 = value_size2   # dimension of each value vector

        # W_a: maps the concatenated [Query | Key] to one score per value vector
        self.attn = nn.Linear(self.query_size + self.key_size, value_size1)

    def forward(self, q, k, v):
        # concatenate Q and K along the feature dimension, score, then normalize with softmax
        # attn_weights = (1, 32)
        attn_weights = F.softmax(self.attn(torch.cat((q[0], k[0]), 1)), dim=1)
        # attn_weights.unsqueeze(0) = (1, 1, 32)
        # v = (1, 32, 64)
        # output = (1, 1, 64)
        output = torch.bmm(attn_weights.unsqueeze(0), v)

        return output, attn_weights

The attn layer takes the concatenation [Query | Key], joined along the column (feature) dimension, and applies a linear map:
f(Q, K) = W_a[Q, K]
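
Spelled out (a summary of what the forward pass above computes, using the shapes from this example), the three steps are score, softmax, and weighted sum:

$$
e = W_a[Q, K] \in \mathbb{R}^{32}, \qquad
a = \operatorname{softmax}(e), \qquad
\mathrm{output} = \sum_{i=1}^{32} a_i \, v_i \in \mathbb{R}^{64}
$$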

The entries of attn_weights correspond to the weights a1, a2, a3, …
[Figure 1]
output is the Attention Value; the bmm amounts to a1·value1 + a2·value2 + … [a matrix multiplication] (see the sketch below the figure).
[Figure 2]
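
To make that last step concrete, here is a small self-contained check (my own sketch, not from the original post) that the bmm between the (1, 1, 32) weight tensor and the (1, 32, 64) value tensor is exactly the weighted sum of the 32 value vectors:

import torch
import torch.nn.functional as F

# random stand-ins with the same shapes as in forward()
attn_weights = F.softmax(torch.randn(1, 32), dim=1)   # (1, 32), rows sum to 1
V = torch.randn(1, 32, 64)                            # 32 value vectors of size 64

# batched matrix multiplication, as used in Attn.forward
out_bmm = torch.bmm(attn_weights.unsqueeze(0), V)     # (1, 1, 64)

# the same computation written as an explicit weighted sum a1*v1 + a2*v2 + ...
out_sum = sum(attn_weights[0, i] * V[0, i] for i in range(32)).view(1, 1, 64)

print(torch.allclose(out_bmm, out_sum, atol=1e-6))    # True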

if __name__ == "__main__":
    query_size = 32
    key_size = 32
    value_size1 = 32    # number of value vectors
    value_size2 = 64    # dimension of each value vector

    attn = Attn(query_size, key_size, value_size1, value_size2)
    Q = torch.randn(1, 1, 32)
    K = torch.randn(1, 1, 32)
    V = torch.randn(1, 32, 64)
    out = attn(Q, K, V)
    print(out[0])   # attention output, shape (1, 1, 64)
    print(out[1])   # attention weights, shape (1, 32)

Running results
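
The printed values are random on every run, but the shapes are fixed by the construction above; a quick sanity check (my own addition, appended to the __main__ block) would be:

    output, attn_weights = attn(Q, K, V)
    print(output.shape)        # torch.Size([1, 1, 64])
    print(attn_weights.shape)  # torch.Size([1, 32])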
