基于pytorch的自注意力机制实现

基于pytorch的自注意力机制实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, in_feat):
        super(Net, self).__init__()
        self.line1 = nn.Linear(in_feat, 128) 
        #此处需要对输入的维度进行变换,变换为自注意力的网络的输入进行线性转换
       
        #=========================以下是pytorch实现自注意力机制的核心代码
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=128, nhead=8)
        #d_model必须保证是nhead的整倍数。
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=2) 
        #num_layers网络层数决定了该结构重复的次数
        #=========================
        
        self.line2 = nn.Linear(128, 1) 
        #自注意力网络输出输入的维度相同,需要根据不同的用途进行线性转换

    def forward(self, x):
        x = F.relu(self.line1(x))
        x = self.transformer_encoder(x)
        x = self.line2(x)
        return x

你可能感兴趣的:(学习笔记,pytorch,深度学习,python)