BERT + BiLSTM + Attention Model

import torch
import torch.nn as nn

class BERTLSTMAttention(nn.Module):
    def __init__(self, bert_model, hidden_size, output_size):
        super(BERTLSTMAttention, self).__init__()
        self.bert = bert_model
        self.lstm = nn.LSTM(input_size=hidden_size,
                            hidden_size=hidden_size,
                            num_layers=1,
                            bidirectional=True,
                            batch_first=True)
        # Learnable scoring layer for additive attention over the LSTM outputs
        self.attention = nn.Linear(2 * hidden_size, 1)
        # The BiLSTM doubles the feature dimension, so the classifier takes 2 * hidden_size
        self.fc = nn.Linear(2 * hidden_size, output_size)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        # BERT encoding: use the per-token sequence output rather than the pooled
        # [CLS] vector, so the LSTM sees the whole sequence (transformers-style API)
        outputs = self.bert(input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state   # (batch, seq_len, hidden_size)

        # BiLSTM over the token representations
        lstm_out, _ = self.lstm(sequence_output)      # (batch, seq_len, 2 * hidden_size)

        # Attention: score each timestep, normalize over the sequence dimension,
        # and take the weighted sum as the sentence representation
        scores = self.attention(lstm_out)                             # (batch, seq_len, 1)
        attention_weights = torch.softmax(scores, dim=1)
        weighted_output = (lstm_out * attention_weights).sum(dim=1)   # (batch, 2 * hidden_size)

        return self.fc(weighted_output)
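
A minimal usage sketch, assuming the Hugging Face transformers package and the bert-base-chinese checkpoint (hypothetical choices; any BERT variant with hidden_size=768 works the same way, and output_size=2 is an assumed binary-classification setup):

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
bert = BertModel.from_pretrained("bert-base-chinese")

model = BERTLSTMAttention(bert, hidden_size=768, output_size=2)

# Tokenize a small batch; padding aligns sequence lengths within the batch
batch = tokenizer(["这部电影很好看", "剧情太拖沓"],
                  padding=True, return_tensors="pt")
logits = model(batch["input_ids"],
               token_type_ids=batch["token_type_ids"],
               attention_mask=batch["attention_mask"])
print(logits.shape)  # torch.Size([2, 2])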
