Seq2Seq with GRU/LSTM: A PyTorch Implementation

1. Overview

  • Advantages
    1. The input and output sequences can have different lengths (unequal-length mapping).
  • Disadvantages
    1. Both the encoder and decoder are RNNs, so computation cannot be parallelized across time steps.
    2. RNNs remain prone to some degree of vanishing and exploding gradients.
    3. The entire input sequence is compressed into a fixed-size state vector, which can lose information.

2. Architecture

(Figure 1: Seq2Seq encoder-decoder architecture.) The encoder reads the source sequence and compresses it into a final hidden state; the decoder is initialized from that state and, in this implementation, also receives it as a context vector concatenated to every decoder input embedding.

3. PyTorch Implementation

import torch
from torch import nn

Encoder

class Seq2SeqEncoder(nn.Module):
    """
    Seq2Seq循环神经网络编码器

    Args:
        vocab_size (int): vocab大小。
        embedding_dim (int): 嵌入层输出维度。
        hidden_size (int): RNN输出维度。
        num_layers (int): RNN层数。
        dropout (float): dropout。

    Inputs:
        x: (batch_size, seq_len)

    Outputs:
        output: (batch_size, seq_len, hidden_size)
        state: (num_layers, batch_size, hidden_size)

    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, dropout=0, **kwargs):
        super(Seq2SeqEncoder, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,
                          dropout=dropout)

    def forward(self, x, *args):
        # x->(batch_size, seq_len)
        x = self.embedding(x)
        # x->(batch_size, seq_len, embedding_dim)
        output, state = self.rnn(x)
        # output->(batch_size, seq_len, hidden_size)
        # state->(num_layers, batch_size, hidden_size)
        return output, state
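
A quick standalone shape check for the encoder (a minimal sketch; all sizes below are arbitrary):

enc = Seq2SeqEncoder(vocab_size=100, embedding_dim=16, hidden_size=8, num_layers=2)
tokens = torch.zeros(4, 7, dtype=torch.long)  # (batch_size=4, seq_len=7) of dummy token ids
out, st = enc(tokens)
print(out.shape)  # torch.Size([4, 7, 8])  -> (batch_size, seq_len, hidden_size)
print(st.shape)   # torch.Size([2, 4, 8])  -> (num_layers, batch_size, hidden_size)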

Decoder

class Seq2SeqDecoder(nn.Module):
    """
    Seq2Seq循环神经网络解码器

    Args:
        vocab_size (int): vocab大小。
        embedding_dim (int): 嵌入层输出维度。
        hidden_size (int): RNN输出维度。
        num_layers (int): RNN层数。
        dropout (float): dropout。

    Inputs:
        X: (batch_size, seq_len)
        state: (num_layers, batch_size, hidden_size)

    Outputs:
        output: (batch_size, num_steps, vocab_size)
        state: (num_layers, batch_size, hidden_size)

    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, dropout=0, **kwargs):
        super(Seq2SeqDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.GRU(input_size=embedding_dim + hidden_size, hidden_size=hidden_size, num_layers=num_layers,
                          dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def init_state(self, enc_outputs, *args):
        enc_output, enc_state = enc_outputs
        return enc_state

    def forward(self, x, state):
        # x->(batch_size, seq_len)
        x = self.embedding(x)
        # x->(batch_size, seq_len, embedding_dim)

        # state[-1]->(batch_size, hidden_size): the top layer's hidden state
        context = state[-1].repeat(x.shape[1], 1, 1)
        # context->(seq_len, batch_size, hidden_size)
        context = context.permute(1, 0, 2)
        # context->(batch_size, seq_len, hidden_size)
        # concatenate x with the encoder's final hidden state, used as context
        x_and_context = torch.cat((x, context), 2)
        # x_and_context->(batch_size, seq_len, embedding_dim + hidden_size)
        output, state = self.rnn(x_and_context, state)
        # output->(batch_size, seq_len, hidden_size)
        # state->(num_layers, batch_size, hidden_size)
        output = self.fc(output)
        # output->(batch_size, seq_len, vocab_size)
        return output, state
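
Note that this decoder assumes nn.GRU, whose state is a single tensor. With nn.LSTM the state is an (h, c) tuple, so state[-1] would have to become h[-1] and init_state would pass the tuple through. A quick standalone check, which also shows that the target length need not match the source length (sizes arbitrary):

encoder = Seq2SeqEncoder(vocab_size=100, embedding_dim=16, hidden_size=8, num_layers=2)
decoder = Seq2SeqDecoder(vocab_size=100, embedding_dim=16, hidden_size=8, num_layers=2)
src = torch.zeros(4, 7, dtype=torch.long)  # source: seq_len = 7
tgt = torch.ones(4, 5, dtype=torch.long)   # target: seq_len = 5
state = decoder.init_state(encoder(src))   # initialize from the encoder's final state
out, state = decoder(tgt, state)
print(out.shape)    # torch.Size([4, 5, 100]) -> (batch_size, seq_len, vocab_size)
print(state.shape)  # torch.Size([2, 4, 8])   -> (num_layers, batch_size, hidden_size)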

Combined model

class EncoderDecoder(nn.Module):
    """
    合并Encoder与Decoder
    """

    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_x, dec_x, *args):
        enc_outputs = self.encoder(enc_x, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_x, dec_state)
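
At inference time there is no target sequence, so tokens are generated one step at a time, with each prediction fed back as the next input. Below is a minimal greedy-decoding sketch; bos_idx (the <bos> start-token id) and the fixed max_len are assumptions, not part of the model above, and a real implementation would also stop once <eos> is produced:

def greedy_decode(model, src, bos_idx, max_len=20):
    """Greedy step-by-step decoding; bos_idx is a hypothetical <bos> token id."""
    with torch.no_grad():
        state = model.decoder.init_state(model.encoder(src))
        # every sequence starts from the <bos> token
        x = torch.full((src.shape[0], 1), bos_idx, dtype=torch.long)
        preds = []
        for _ in range(max_len):
            output, state = model.decoder(x, state)  # output->(batch_size, 1, vocab_size)
            x = output.argmax(dim=2)                 # feed the prediction back in
            preds.append(x)
        return torch.cat(preds, dim=1)               # (batch_size, max_len)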

Test

batch_size = 32
seq_len = 15
vocab_size = 3000
embedding_dim = 100
hidden_size = 8
num_layers = 3

enc_x = torch.zeros(batch_size, seq_len, dtype=torch.long)  # dummy source token ids
dec_x = torch.ones(batch_size, seq_len, dtype=torch.long)   # dummy target token ids

encoder = Seq2SeqEncoder(vocab_size, embedding_dim, hidden_size, num_layers)
decoder = Seq2SeqDecoder(vocab_size, embedding_dim, hidden_size, num_layers)
model = EncoderDecoder(encoder, decoder)

output, state = model(enc_x, dec_x)
print(output.shape)  # torch.Size([32, 15, 3000])
print(state.shape)   # torch.Size([3, 32, 8])
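
Training uses teacher forcing: assuming the target batch already starts with a <bos> token, the decoder is fed the target shifted right and trained to predict the next token with cross-entropy. A minimal sketch of one training step (the padding id, assumed here to be 0, is an assumption):

criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding positions; pad id 0 assumed
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

src = torch.randint(1, vocab_size, (batch_size, seq_len))
tgt = torch.randint(1, vocab_size, (batch_size, seq_len))

dec_input = tgt[:, :-1]  # decoder input: target shifted right
labels = tgt[:, 1:]      # labels: the next token at each position
logits, _ = model(src, dec_input)
loss = criterion(logits.reshape(-1, vocab_size), labels.reshape(-1))

optimizer.zero_grad()
loss.backward()
optimizer.step()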

4. References

  • 《动手学深度学习》 (Dive into Deep Learning), "Sequence-to-Sequence Learning (seq2seq)"
