Building a classification model in PyTorch with a Transformer and positional encoding

Transformer models are now widely used in NLP, CV, and other domains with very good results. This post records how to build a Transformer classification model in PyTorch; the code is as follows:

import torch
import numpy as np
import torch.nn as nn

class trans_model(nn.Module):
    def __init__(self, d_model, nhead, num_layers):
        super(trans_model, self).__init__()
        # batch_first=True (PyTorch 1.9+) so the encoder expects inputs of shape (batch, seq_len, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.linear = nn.Linear(d_model, 2)  # the 2 here is the number of output classes
        self.num_labels = 2

    def forward(self, inputs):
        # add the sinusoidal positional encoding, matched to the input's shape, dtype, and device
        pos_enc = PositionalEncoding(max_seq_len=inputs.size(1), embed_dim=inputs.size(-1))
        inputs = inputs + pos_enc.to(inputs.device, inputs.dtype)
        trans_out = self.transformer_encoder(inputs)        # (batch, seq_len, d_model)
        linear_out = self.linear(trans_out.mean(dim=1))     # mean-pool over the sequence, then classify

        return linear_out


def PositionalEncoding(max_seq_len, embed_dim):
    # sinusoidal positional encoding: sin on even dimensions, cos on odd dimensions
    positional_encoding = np.array([
        [np.sin(pos / np.power(10000, (i - i % 2) / embed_dim)) if i % 2 == 0 else
         np.cos(pos / np.power(10000, (i - i % 2) / embed_dim))
         for i in range(embed_dim)]
        for pos in range(max_seq_len)])

    # (max_seq_len, embed_dim) float tensor; broadcasts across the batch dimension
    return torch.tensor(positional_encoding, dtype=torch.float32)
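
For reference, here is a minimal usage sketch of the model above. The hyperparameters (d_model=512, nhead=8, num_layers=6), batch size, sequence length, and random inputs are illustrative placeholders, not values from a real task:

# Minimal usage sketch; all hyperparameters and dummy data below are illustrative only
model = trans_model(d_model=512, nhead=8, num_layers=6)

dummy_inputs = torch.randn(4, 128, 512)    # (batch, seq_len, d_model) token embeddings
logits = model(dummy_inputs)
print(logits.shape)                        # torch.Size([4, 2]): one pair of class logits per sample

# a typical training step would pair these logits with a cross-entropy loss
labels = torch.randint(0, 2, (4,))
loss = nn.CrossEntropyLoss()(logits, labels)

Since the linear layer here is applied to the mean-pooled encoder output, the model produces one logit vector per sample, which is what a sequence-level classifier needs.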
