1. Overview
- Advantages:
    - Can map an input sequence to an output sequence of a different length.
- Disadvantages:
    - Both the encoder and decoder are RNNs, so computation cannot be parallelized across time steps.
    - Vanishing and exploding gradients remain possible to some degree.
    - Compressing the entire input sequence into a fixed-size state vector inevitably loses information.
2. Architecture
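The encoder reads the input sequence token by token and compresses it into its final hidden state (the context); the decoder is initialized from that state and, conditioned on the context, generates the output sequence one token at a time.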
3. PyTorch Implementation
```python
import torch
from torch import nn
```
Encoder
```python
class Seq2SeqEncoder(nn.Module):
    """
    Seq2Seq RNN encoder.

    Args:
        vocab_size (int): Size of the vocabulary.
        embedding_dim (int): Output dimension of the embedding layer.
        hidden_size (int): Hidden dimension of the RNN.
        num_layers (int): Number of RNN layers.
        dropout (float): Dropout probability between RNN layers.

    Inputs:
        x: (batch_size, seq_len)

    Outputs:
        output: (batch_size, seq_len, hidden_size)
        state: (num_layers, batch_size, hidden_size)
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, dropout=0, **kwargs):
        super(Seq2SeqEncoder, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.rnn = nn.GRU(input_size=embedding_dim, hidden_size=hidden_size,
                          num_layers=num_layers, batch_first=True, dropout=dropout)

    def forward(self, x, *args):
        # (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
        x = self.embedding(x)
        # output: per-step hidden states of the top layer;
        # state: final hidden state of every layer.
        output, state = self.rnn(x)
        return output, state
```
Decoder
```python
class Seq2SeqDecoder(nn.Module):
    """
    Seq2Seq RNN decoder.

    Args:
        vocab_size (int): Size of the vocabulary.
        embedding_dim (int): Output dimension of the embedding layer.
        hidden_size (int): Hidden dimension of the RNN.
        num_layers (int): Number of RNN layers.
        dropout (float): Dropout probability between RNN layers.

    Inputs:
        x: (batch_size, seq_len)
        state: (num_layers, batch_size, hidden_size)

    Outputs:
        output: (batch_size, seq_len, vocab_size)
        state: (num_layers, batch_size, hidden_size)
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, dropout=0, **kwargs):
        super(Seq2SeqDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # The context vector is concatenated to the embedding at every step,
        # so the RNN input size is embedding_dim + hidden_size.
        self.rnn = nn.GRU(input_size=embedding_dim + hidden_size, hidden_size=hidden_size,
                          num_layers=num_layers, dropout=dropout, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def init_state(self, enc_outputs, *args):
        # Initialize the decoder with the encoder's final hidden state.
        enc_output, enc_state = enc_outputs
        return enc_state

    def forward(self, x, state):
        # (batch_size, seq_len) -> (batch_size, seq_len, embedding_dim)
        x = self.embedding(x)
        # Broadcast the top layer of the state as the context:
        # (batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
        context = state[-1].repeat(x.shape[1], 1, 1)
        context = context.permute(1, 0, 2)
        x_and_context = torch.cat((x, context), 2)
        output, state = self.rnn(x_and_context, state)
        # Project hidden states to vocabulary logits.
        output = self.fc(output)
        return output, state
```
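Note the design choice here: `state[-1]`, the top-layer hidden state, is concatenated to the decoder input at every time step, so each decoding step sees the context directly rather than only through the initial hidden state.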
Combining encoder and decoder
```python
class EncoderDecoder(nn.Module):
    """
    Combine the encoder and the decoder into a single model.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_x, dec_x, *args):
        enc_outputs = self.encoder(enc_x, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_x, dec_state)
```
Test
```python
batch_size = 32
seq_len = 15
vocab_size = 3000
embedding_dim = 100
hidden_size = 8
num_layers = 3

enc_x = torch.zeros(batch_size, seq_len, dtype=torch.long)
dec_x = torch.ones(batch_size, seq_len, dtype=torch.long)

encoder = Seq2SeqEncoder(vocab_size, embedding_dim, hidden_size, num_layers)
decoder = Seq2SeqDecoder(vocab_size, embedding_dim, hidden_size, num_layers)
model = EncoderDecoder(encoder, decoder)

output, state = model(enc_x, dec_x)
print(output.shape)  # torch.Size([32, 15, 3000])
print(state.shape)   # torch.Size([3, 32, 8])
```
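The forward pass above feeds the full target sequence to the decoder in one shot, as done during training with teacher forcing. At inference time the decoder instead consumes its own previous prediction step by step. Below is a minimal greedy-decoding sketch built on the classes above; `greedy_decode`, `bos_idx`, `eos_idx`, and `max_len` are illustrative assumptions, not part of the original code.

```python
def greedy_decode(model, enc_x, bos_idx, eos_idx, max_len=20):
    """Greedy step-by-step decoding (assumed helper for illustration)."""
    model.eval()
    with torch.no_grad():
        enc_outputs = model.encoder(enc_x)
        state = model.decoder.init_state(enc_outputs)
        # Start every sequence in the batch with the <bos> token.
        x = torch.full((enc_x.shape[0], 1), bos_idx, dtype=torch.long)
        tokens = []
        for _ in range(max_len):
            # output: (batch_size, 1, vocab_size)
            output, state = model.decoder(x, state)
            x = output.argmax(dim=2)   # pick the most likely next token
            tokens.append(x)
            if (x == eos_idx).all():   # stop once every sequence emitted <eos>
                break
    return torch.cat(tokens, dim=1)    # (batch_size, decoded_len)

# Token ids 1/2 for <bos>/<eos> are arbitrary placeholders.
pred = greedy_decode(model, enc_x, bos_idx=1, eos_idx=2)
print(pred.shape)
```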
4. References
- *Dive into Deep Learning*, "Sequence-to-Sequence Learning (seq2seq)"