在上一篇 Seq2Seq 的文章中,我们介绍了如何用 encoder-decoder 框架实现机器翻译任务;现在我们在此基础上加入注意力(attention)机制。
class Attention(nn.Module):
    """Luong-style attention over bidirectional encoder states.

    Projects the encoder context into the decoder hidden space, scores it
    against each decoder output step, and returns the attention-combined
    decoder outputs together with the attention weights.
    """

    def __init__(self, enc_hidden_size, dec_hidden_size):
        super(Attention, self).__init__()
        self.enc_hidden_size = enc_hidden_size
        self.dec_hidden_size = dec_hidden_size
        # The encoder is assumed bidirectional, hence 2*enc_hidden_size inputs.
        self.linear_in = nn.Linear(enc_hidden_size * 2, dec_hidden_size, bias=False)
        self.linear_out = nn.Linear(enc_hidden_size * 2 + dec_hidden_size, dec_hidden_size)

    def forward(self, output, context, mask):
        """Attend over `context` for every decoder step in `output`.

        Args:
            output:  (batch_size, output_len, dec_hidden_size) decoder states.
            context: (batch_size, context_len, 2*enc_hidden_size) encoder states.
            mask:    (batch_size, output_len, context_len) bool tensor, True at
                     positions that must NOT be attended to (padding).

        Returns:
            output: (batch_size, output_len, dec_hidden_size) combined states.
            attn:   (batch_size, output_len, context_len) attention weights.
        """
        batch_size = output.size(0)
        output_len = output.size(1)
        input_len = context.size(1)

        # Project context into the decoder hidden space:
        # (batch_size, context_len, dec_hidden_size)
        context_in = self.linear_in(
            context.view(batch_size * input_len, -1)
        ).view(batch_size, input_len, -1)

        # Dot-product scores: (batch_size, output_len, context_len)
        attn = torch.bmm(output, context_in.transpose(1, 2))

        # BUG FIX: the original called the out-of-place `masked_fill` (via
        # `.data`) and discarded its return value, so padding positions were
        # never masked. Use the in-place `masked_fill_` on the tensor itself.
        attn.masked_fill_(mask, -1e6)
        attn = F.softmax(attn, dim=2)

        # Weighted sum of encoder states:
        # (batch_size, output_len, 2*enc_hidden_size)
        context = torch.bmm(attn, context)

        # Concatenate (2*enc_hidden_size + dec_hidden_size) and project back
        # down to dec_hidden_size.
        output = torch.cat((context, output), dim=2)
        output = output.view(batch_size * output_len, -1)
        output = torch.tanh(self.linear_out(output))
        output = output.view(batch_size, output_len, -1)
        return output, attn
class Decoder(nn.Module):
    """GRU decoder with attention over the encoder outputs (teacher forcing)."""

    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(enc_hidden_size, dec_hidden_size)
        # BUG FIX: the original passed the undefined name `hidden_size`;
        # the GRU hidden size must be dec_hidden_size.
        self.rnn = nn.GRU(embed_size, dec_hidden_size, batch_first=True)
        self.out = nn.Linear(dec_hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_mask(self, y_lengths, ctx_lengths):
        """Build the attention padding mask.

        Added because forward() called this method but it was never defined.
        Returns a (batch, max_y_len, max_ctx_len) bool tensor that is True
        wherever EITHER the decoder step or the context step is padding —
        i.e. True at positions the attention must ignore.
        """
        device = y_lengths.device
        max_y_len = int(y_lengths.max())
        max_ctx_len = int(ctx_lengths.max())
        # True at valid (non-padding) positions on each side.
        y_valid = torch.arange(max_y_len, device=device)[None, :] < y_lengths[:, None]
        ctx_valid = torch.arange(max_ctx_len, device=device)[None, :] < ctx_lengths[:, None]
        return ~(y_valid[:, :, None] & ctx_valid[:, None, :])

    def forward(self, ctx, ctx_lengths, y, y_lengths, hid):
        """Decode one batch of target sequences with teacher forcing.

        Args:
            ctx:         (batch, ctx_len, 2*enc_hidden_size) encoder outputs.
            ctx_lengths: (batch,) valid lengths of ctx.
            y:           (batch, y_len) target token ids.
            y_lengths:   (batch,) valid lengths of y.
            hid:         (1, batch, dec_hidden_size) initial GRU hidden state.

        Returns:
            output: (batch, y_len, vocab_size) log-probabilities.
            hid:    final GRU hidden state, original batch order.
            attn:   (batch, y_len, ctx_len) attention weights.
        """
        # Sort by target length (descending) so sequences can be packed.
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()]

        y_sorted = self.dropout(self.embed(y_sorted))  # (batch, y_len, embed_size)

        # pack_padded_sequence accepts a CPU tensor directly; no need for the
        # old `.data.numpy()` round-trip.
        packed_seq = nn.utils.rnn.pack_padded_sequence(
            y_sorted, sorted_len.long().cpu(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

        # Restore the original batch order.
        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        mask = self.create_mask(y_lengths, ctx_lengths)
        output, attn = self.attention(output_seq, ctx, mask)
        output = F.log_softmax(self.out(output), -1)
        return output, hid, attn
class Seq2Seq(nn.Module):
    """Glue module: run the encoder, then decode with attention.

    Holds an encoder and a decoder and wires the encoder's outputs and final
    hidden state into the decoder for one forward pass.
    """

    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        """Encode the source batch, then decode the target batch.

        Returns the decoder's per-token log-probabilities and the attention
        weights; the decoder's final hidden state is dropped.
        """
        enc_states, enc_hid = self.encoder(x, x_lengths)
        dec_out, _, attn_weights = self.decoder(
            ctx=enc_states,
            ctx_lengths=x_lengths,
            y=y,
            y_lengths=y_lengths,
            hid=enc_hid,
        )
        return dec_out, attn_weights