本项目尽可能复现Luong的attention模型,数据集小,只有一万多个句子的训练数据,所以训练出来的模型效果并不好。如果想训练一个好一点的模型,可以参考下面的资料。
课件
论文
- Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation
- Effective Approaches to Attention-based Neural Machine Translation
- Neural Machine Translation by Jointly Learning to Align and Translate
PyTorch代码
- seq2seq-tutorial
- Tutorial from Ben Trevett
- IBM seq2seq
- OpenNMT-py 较好
更多关于Machine Translation
- Beam Search
- Pointer network 文本摘要
- Copy Mechanism 文本摘要
- Coverage Loss
- ConvSeq2Seq
- Transformer
- Tensor2Tensor
本项目的完整代码和数据集见下方的 GitHub 链接,可一键运行,开箱即用。
github代码
import os
import sys
import math
from collections import Counter
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import nltk
def load_data(in_file):
    """Read a tab-separated English/Chinese parallel corpus.

    Every usable line holds an English sentence, a tab, and a Chinese
    sentence. The English side is lower-cased and tokenized with
    ``nltk.word_tokenize``; the Chinese side is split into single characters.
    Both token lists are wrapped in the literal markers "BOS" and "EOS".

    Lines with fewer than two tab-separated fields (for example a blank
    trailing line) are skipped instead of raising ``IndexError``.

    Args:
        in_file: Path of the UTF-8 encoded corpus file.

    Returns:
        A pair ``(en, cn)`` of equally long lists; ``en[k]`` and ``cn[k]`` are
        the token lists of the k-th usable line.
    """
    en = []
    cn = []
    with open(in_file, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.strip().split("\t")
            if len(fields) < 2:
                # Not a sentence pair; keep the two output lists aligned.
                continue
            en.append(["BOS"] + nltk.word_tokenize(fields[0].lower()) + ["EOS"])
            cn.append(["BOS"] + list(fields[1]) + ["EOS"])
    return en, cn
# Corpus files handed to load_data, which splits every line on a tab.
train_file = "nmt/nmt/en-cn/train.txt"
dev_file = "nmt/nmt/en-cn/dev.txt"
# Tokenized English / Chinese sentence lists for both splits.
train_en, train_cn = load_data(train_file)
dev_en, dev_cn = load_data(dev_file)
# Visual sanity check of the first ten tokenized training sentences.
print(train_en[:10])
print(train_cn[:10])
# Ids that build_dict assigns to the "UNK" and "PAD" vocabulary entries.
UNK_IDX = 0
PAD_IDX = 1
def build_dict(sentences, max_words=50000):
    """Build a token -> id vocabulary from tokenized sentences.

    The ``max_words`` most frequent tokens receive ids starting at 2, in
    descending frequency order. The entries "UNK" and "PAD" are then set to
    ``UNK_IDX`` and ``PAD_IDX``. The number of kept tokens is printed.

    Args:
        sentences: Iterable of token lists.
        max_words: Maximum number of distinct corpus tokens to keep.

    Returns:
        ``(word_dict, total_words)`` where ``total_words`` is the number of
        kept tokens plus two.
    """
    frequencies = Counter(
        token for sentence in sentences for token in sentence
    )
    kept = frequencies.most_common(max_words)
    print(len(kept))

    word_dict = {}
    for token_id, (token, _) in enumerate(kept, start=2):
        word_dict[token] = token_id
    word_dict["UNK"] = UNK_IDX
    word_dict["PAD"] = PAD_IDX

    return word_dict, len(kept) + 2
# Both vocabularies are built from the training split only.
en_dict, en_total_words = build_dict(train_en)
cn_dict, cn_total_words = build_dict(train_cn)
# Reverse lookups (id -> token), used later to print sentences as text.
inv_en_dict = {v: k for k, v in en_dict.items()}
inv_cn_dict = {v: k for k, v in cn_dict.items()}
# Inspect the English vocabulary: size, first entries and last entries.
print(en_total_words)
print(list(en_dict.items())[:10])
print(list(en_dict.items())[-10:])
print("---"*20)
# Same inspection for the Chinese vocabulary.
print(cn_total_words)
print(list(cn_dict.items())[:10])
print(list(cn_dict.items())[-10:])
print("---"*20)
# First entries of the reverse lookups.
print(list(inv_en_dict.items())[:10])
print(list(inv_cn_dict.items())[:10])
def encode(en_sentences, cn_sentences, en_dict, cn_dict, sort_by_len=True):
    """Convert tokenized sentence pairs into lists of vocabulary ids.

    Tokens missing from a dictionary are mapped to id 0.

    Args:
        en_sentences: English token lists.
        cn_sentences: Chinese token lists, parallel to ``en_sentences``.
        en_dict: English token -> id mapping.
        cn_dict: Chinese token -> id mapping.
        sort_by_len: When true, both outputs are reordered by ascending
            English sentence length (ties keep their input order).

    Returns:
        ``(out_en_sentences, out_cn_sentences)`` as lists of id lists.
    """
    en_ids = [[en_dict.get(token, 0) for token in sentence]
              for sentence in en_sentences]
    cn_ids = [[cn_dict.get(token, 0) for token in sentence]
              for sentence in cn_sentences]

    if not sort_by_len:
        return en_ids, cn_ids

    # Order is decided by the English side only and applied to both sides so
    # that sentence pairs stay aligned.
    order = sorted(range(len(en_ids)), key=lambda pos: len(en_ids[pos]))
    return [en_ids[pos] for pos in order], [cn_ids[pos] for pos in order]
# Replace tokens by vocabulary ids; encode sorts each split by English
# sentence length by default.
train_en, train_cn = encode(train_en, train_cn, en_dict, cn_dict)
dev_en, dev_cn = encode(dev_en, dev_cn, en_dict, cn_dict)
# Demonstration of the "argsort via sorted(range(...))" idiom used by encode.
seq = [5,4,6,9,10]
print(sorted(range(5), key=lambda x: seq[x]))
print(sorted(range(4), key=lambda x: seq[x]))
print(train_en[:10])
print(train_cn[:10])
print("---"*20)
# Map one encoded training pair back to tokens as a round-trip check.
# NOTE(review): index 10000 requires more than 10000 training pairs.
k=10000
print([inv_cn_dict[i] for i in train_cn[k]])
print([inv_en_dict[i] for i in train_en[k]])
print(" ".join([inv_cn_dict[i] for i in train_cn[k]]))
print(" ".join([inv_en_dict[i] for i in train_en[k]]))
# Demonstration of np.arange with and without a step, as used by get_batches.
print(np.arange(0, 100, 15))
print(np.arange(0, 15))
def get_batches(n, batch_size, shuffle=True):
    """Split the index range ``0..n-1`` into consecutive mini-batches.

    Args:
        n: Number of examples.
        batch_size: Maximum number of indices per batch; the last batch may
            be shorter.
        shuffle: When true, the order of the batches (not the indices inside
            a batch) is shuffled with ``np.random.shuffle``.

    Returns:
        A list of 1-D numpy arrays of consecutive indices.
    """
    starts = np.arange(0, n, batch_size)
    if shuffle:
        np.random.shuffle(starts)
    return [np.arange(start, min(start + batch_size, n)) for start in starts]
# Example call; the result is discarded here. With the default shuffle=True
# this call still advances numpy's global random state.
get_batches(100,15)
def sent_padding(seqs):
    """Right-pad id sequences with zeros into one int32 matrix.

    Args:
        seqs: Non-empty list of id sequences.

    Returns:
        ``(x, x_lengths)``: ``x`` is an int32 array of shape
        ``(len(seqs), longest sequence)`` holding each sequence followed by
        zeros; ``x_lengths`` is an int32 array with the original lengths.
    """
    lengths = [len(seq) for seq in seqs]
    padded = np.zeros((len(seqs), np.max(lengths))).astype('int32')
    # Each ``row`` is a view into ``padded``, so the assignment fills it in place.
    for row, seq in zip(padded, seqs):
        row[:len(seq)] = seq
    return padded, np.array(lengths).astype("int32")
def gen_examples(en_sentences, cn_sentences, batch_size):
    """Group encoded sentence pairs into padded mini-batches.

    Batch index ranges come from ``get_batches`` (with its default
    shuffling); each side of a batch is padded with ``sent_padding``.

    Args:
        en_sentences: Encoded English sentences.
        cn_sentences: Encoded Chinese sentences, parallel to ``en_sentences``.
        batch_size: Maximum number of pairs per batch.

    Returns:
        A list of tuples ``(mb_x, mb_x_len, mb_y, mb_y_len)``.
    """
    examples = []
    for indices in get_batches(len(en_sentences), batch_size):
        src, src_len = sent_padding([en_sentences[j] for j in indices])
        tgt, tgt_len = sent_padding([cn_sentences[j] for j in indices])
        examples.append((src, src_len, tgt, tgt_len))
    return examples
batch_size = 64
# Pre-build all padded mini-batches once; the same batches are reused in
# every epoch.
train_data = gen_examples(train_en, train_cn, batch_size)
random.shuffle(train_data)
dev_data = gen_examples(dev_en, dev_cn, batch_size)
# Shapes of the four arrays of the first training batch:
# padded English ids, English lengths, padded Chinese ids, Chinese lengths.
print(train_data[0][0].shape)
print(train_data[0][1].shape)
print(train_data[0][2].shape)
print(train_data[0][3].shape)
print(train_data[0])
class PlainEncoder(nn.Module):
    """Unidirectional GRU encoder for padded, variable-length batches.

    The embedding size equals the GRU hidden size.
    """

    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        """Encode a batch of id sequences.

        Args:
            x: Integer tensor of shape (batch, max_len).
            lengths: Integer tensor of shape (batch,) with the valid length
                of every row of ``x``.

        Returns:
            ``(out, hid)``: the per-step GRU outputs in the original batch
            order (zeros after each sequence's end) and the last entry of the
            final hidden state, kept with its leading dimension.
        """
        # Packing needs the batch ordered from longest to shortest.
        desc_lengths, order = lengths.sort(0, descending=True)
        order = order.long()

        embedded = self.dropout(self.embed(x[order]))
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, desc_lengths.long().cpu().numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # Sorting a permutation yields its inverse, which undoes the reordering.
        _, restore = order.sort(0, descending=False)
        restore = restore.long()
        out = out[restore].contiguous()
        hid = hid[:, restore].contiguous()
        return out, hid[[-1]]
class PlainDecoder(nn.Module):
    """GRU decoder without attention that emits log-probabilities per step."""

    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, y, y_lengths, hid):
        """Run the decoder over the given target-side input tokens.

        Args:
            y: Integer tensor (batch, max_len) of decoder input ids.
            y_lengths: Integer tensor (batch,) with the valid length per row.
            hid: Initial GRU hidden state of shape (1, batch, hidden_size).

        Returns:
            ``(output, hid)``: log-softmax scores over the vocabulary for
            every step, and the final hidden state, both in the original
            batch order.
        """
        # Packing needs the batch ordered from longest to shortest; the
        # initial hidden state has to follow the same permutation.
        desc_lengths, order = y_lengths.sort(0, descending=True)
        order = order.long()
        hid = hid[:, order]

        embedded = self.dropout(self.embed(y[order]))
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, desc_lengths.long().cpu().numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # Sorting a permutation yields its inverse, which undoes the reordering.
        _, restore = order.sort(0, descending=False)
        restore = restore.long()
        states = unpacked[restore].contiguous()
        hid = hid[:, restore].contiguous()

        return F.log_softmax(self.out(states), -1), hid
class PlainSeq2Seq(nn.Module):
    """Encoder-decoder wrapper without attention.

    Both methods return ``None`` in the slot where the attention-based model
    returns attention weights, so the two models can share calling code.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        """Score the target inputs ``y`` given the source batch ``x``.

        Returns:
            ``(output, None)`` where ``output`` is whatever the decoder
            returns as its first value.
        """
        _, hid = self.encoder(x, x_lengths)
        output, _ = self.decoder(y, y_lengths, hid)
        return output, None

    def translate(self, x, x_lengths, y, max_length=10):
        """Greedily decode ``max_length`` tokens.

        Args:
            x: Source id tensor (batch, src_len).
            x_lengths: Valid source lengths (batch,).
            y: First decoder input of shape (batch, 1).
            max_length: Number of decoding steps.

        Returns:
            ``(preds, None)`` with ``preds`` of shape (batch, max_length).
        """
        _, hid = self.encoder(x, x_lengths)
        batch_size = x.shape[0]
        preds = []
        for _step in range(max_length):
            # Each step feeds exactly one token per sentence.
            output, hid = self.decoder(y=y,
                                       y_lengths=torch.ones(batch_size).long().to(y.device),
                                       hid=hid)
            _, best = output.max(2)
            y = best.view(batch_size, 1)
            preds.append(y)
        return torch.cat(preds, 1), None
# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dropout = 0.2
hidden_size = 100
# Baseline model without attention: the encoder reads the English vocabulary,
# the decoder predicts over the Chinese vocabulary.
encoder = PlainEncoder(vocab_size=en_total_words,
                       hidden_size=hidden_size,
                       dropout=dropout)
decoder = PlainDecoder(vocab_size=cn_total_words,
                       hidden_size=hidden_size,
                       dropout=dropout)
model = PlainSeq2Seq(encoder, decoder)
class LanguageModelCriterion(nn.Module):
    """Masked negative log-likelihood averaged over the unmasked positions."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target, mask):
        """Compute the masked mean negative log-likelihood.

        Args:
            input: Log-probabilities of shape (batch, seq_len, vocab).
            target: Integer target ids with batch * seq_len elements.
            mask: Weights with batch * seq_len elements; positions with
                weight 0 do not contribute.

        Returns:
            Scalar tensor: sum of weighted negative log-probabilities of the
            targets divided by the sum of the mask.
        """
        vocab = input.size(2)
        flat_logp = input.contiguous().view(-1, vocab)
        flat_target = target.contiguous().view(-1, 1)
        flat_mask = mask.contiguous().view(-1, 1)

        # Pick the log-probability of each target id, then zero out padding.
        nll = -flat_logp.gather(1, flat_target) * flat_mask
        return torch.sum(nll) / torch.sum(flat_mask)
# train() and evaluate() read `loss_fn`, `optimizer` and `device` as module
# globals, so these bindings select what gets optimized.
model = model.to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())
def evaluate(model, data):
    """Print the per-token loss of ``model`` over ``data``.

    The model is switched to eval mode and left in that mode. Uses the
    module-level ``device`` and ``loss_fn``.

    Args:
        model: Callable as ``model(x, x_len, y_input, y_len)`` returning
            ``(log_probs, attn)``.
        data: Iterable of ``(mb_x, mb_x_len, mb_y, mb_y_len)`` numpy batches.
    """
    model.eval()
    total_num_words = total_loss = 0.
    with torch.no_grad():
        for batch in data:
            src_np, src_len_np, tgt_np, tgt_len_np = batch
            src = torch.from_numpy(src_np).to(device).long()
            src_len = torch.from_numpy(src_len_np).to(device).long()
            # Decoder input drops the last column, the target drops the
            # first one, so both are one step shorter than the sentence.
            dec_input = torch.from_numpy(tgt_np[:, :-1]).to(device).long()
            dec_target = torch.from_numpy(tgt_np[:, 1:]).to(device).long()
            tgt_len = torch.from_numpy(tgt_len_np - 1).to(device).long()
            tgt_len[tgt_len <= 0] = 1

            pred, _ = model(src, src_len, dec_input, tgt_len)

            # 1.0 for real target positions, 0.0 for padding.
            positions = torch.arange(tgt_len.max().item(), device=device)
            target_mask = (positions[None, :] < tgt_len[:, None]).float()

            loss = loss_fn(pred, dec_target, target_mask)

            batch_words = torch.sum(tgt_len).item()
            total_loss += loss.item() * batch_words
            total_num_words += batch_words
    print("Evaluation loss", total_loss/total_num_words)
def train(model, data, num_epochs=2):
    """Train ``model`` on ``data`` for ``num_epochs`` epochs.

    Uses the module-level ``device``, ``loss_fn`` and ``optimizer``. Gradients
    are clipped to a norm of 5. The batch loss is printed every 100
    iterations, the epoch's per-token loss after every epoch, and
    ``evaluate(model, dev_data)`` runs on every fifth epoch (starting with
    epoch 0).

    Args:
        model: Callable as ``model(x, x_len, y_input, y_len)`` returning
            ``(log_probs, attn)``.
        data: Sequence of ``(mb_x, mb_x_len, mb_y, mb_y_len)`` numpy batches.
        num_epochs: Number of passes over ``data``.
    """
    for epoch in range(num_epochs):
        # evaluate() leaves the model in eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        total_num_words = total_loss = 0.
        for it, batch in enumerate(data):
            src_np, src_len_np, tgt_np, tgt_len_np = batch
            src = torch.from_numpy(src_np).to(device).long()
            src_len = torch.from_numpy(src_len_np).to(device).long()
            # Decoder input drops the last column, the target drops the
            # first one, so both are one step shorter than the sentence.
            dec_input = torch.from_numpy(tgt_np[:, :-1]).to(device).long()
            dec_target = torch.from_numpy(tgt_np[:, 1:]).to(device).long()
            tgt_len = torch.from_numpy(tgt_len_np - 1).to(device).long()
            tgt_len[tgt_len <= 0] = 1

            optimizer.zero_grad()
            pred, _ = model(src, src_len, dec_input, tgt_len)

            # 1.0 for real target positions, 0.0 for padding.
            positions = torch.arange(tgt_len.max().item(), device=device)
            target_mask = (positions[None, :] < tgt_len[:, None]).float()

            loss = loss_fn(pred, dec_target, target_mask)

            batch_words = torch.sum(tgt_len).item()
            total_loss += loss.item() * batch_words
            total_num_words += batch_words

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()

            if it % 100 == 0:
                print("Epoch", epoch, "iteration", it, "loss", loss.item())

        print("Epoch", epoch, "Training loss", total_loss/total_num_words)
        if epoch % 5 == 0:
            evaluate(model, dev_data)
# Train the baseline model that is currently bound to `model`.
train(model, train_data, num_epochs=20)
def translate_dev(i):
    """Print dev pair ``i`` and the model's greedy translation of it.

    Three lines are printed: the English source tokens, the Chinese reference
    tokens (both space-separated), and the predicted characters up to but not
    including the first "EOS".

    Decoding runs in eval mode and without gradient tracking. ``train`` leaves
    the model in training mode, so without this switch dropout would stay
    active and make the translation noisy. The model's previous mode is
    restored afterwards.

    Uses the module-level ``model``, ``device``, ``dev_en``, ``dev_cn``,
    ``cn_dict``, ``inv_en_dict`` and ``inv_cn_dict``.

    Args:
        i: Index of the sentence pair in ``dev_en`` / ``dev_cn``.
    """
    en_sent = " ".join([inv_en_dict[w] for w in dev_en[i]])
    print(en_sent)
    cn_sent = " ".join([inv_cn_dict[w] for w in dev_cn[i]])
    print(cn_sent)

    # Batch of one sentence: ids of shape (1, length) plus its length.
    mb_x = torch.from_numpy(np.array(dev_en[i]).reshape(1, -1)).long().to(device)
    mb_x_len = torch.from_numpy(np.array([len(dev_en[i])])).long().to(device)
    bos = torch.tensor([[cn_dict["BOS"]]], dtype=torch.long, device=device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            translation, attn = model.translate(mb_x, mb_x_len, bos)
    finally:
        model.train(was_training)

    words = [inv_cn_dict[idx] for idx in translation.cpu().view(-1).tolist()]
    trans = []
    for word in words:
        if word == "EOS":
            break
        trans.append(word)
    print("".join(trans))
# Show source, reference and prediction for twenty dev sentences, separated
# by blank lines.
for i in range(500, 520):
    translate_dev(i)
    print()
class Encoder(nn.Module):
    """Bidirectional GRU encoder that also produces a decoder start state."""

    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, enc_hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(enc_hidden_size * 2, dec_hidden_size)

    def forward(self, x, lengths):
        """Encode a batch of id sequences.

        Args:
            x: Integer tensor of shape (batch, max_len).
            lengths: Integer tensor (batch,) with the valid length per row.

        Returns:
            ``(out, hid)``: ``out`` has shape (batch, max_len,
            2 * enc_hidden_size) and holds the GRU outputs in the original
            batch order; ``hid`` has shape (1, batch, dec_hidden_size) and is
            ``tanh(fc(...))`` of the last two final hidden states
            concatenated.
        """
        # Packing needs the batch ordered from longest to shortest.
        desc_lengths, order = lengths.sort(0, descending=True)
        order = order.long()

        embedded = self.dropout(self.embed(x[order]))
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, desc_lengths.long().cpu().numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # Sorting a permutation yields its inverse, which undoes the reordering.
        _, restore = order.sort(0, descending=False)
        restore = restore.long()
        out = out[restore].contiguous()
        hid = hid[:, restore].contiguous()

        # Join the last two final states and project them to the decoder size.
        joined = torch.cat([hid[-2], hid[-1]], dim=1)
        start_state = torch.tanh(self.fc(joined)).unsqueeze(0)
        return out, start_state
class Attention(nn.Module):
    """Multiplicative attention over the encoder outputs.

    The score of a decoder state against an encoder state is
    ``decoder_state . linear_in(encoder_state)``. The attended context is
    concatenated with the decoder state and projected back to
    ``dec_hidden_size`` through ``tanh(linear_out(...))``.

    Encoder states are expected to have ``2 * enc_hidden_size`` features.
    """

    def __init__(self, enc_hidden_size, dec_hidden_size):
        super(Attention, self).__init__()
        self.enc_hidden_size = enc_hidden_size
        self.dec_hidden_size = dec_hidden_size
        self.linear_in = nn.Linear(enc_hidden_size * 2, dec_hidden_size, bias=False)
        self.linear_out = nn.Linear(enc_hidden_size * 2 + dec_hidden_size, dec_hidden_size)

    def forward(self, output, context, mask):
        """Attend from decoder states to encoder states.

        Args:
            output: Decoder states, shape (batch, output_len, dec_hidden_size).
            context: Encoder states, shape
                (batch, input_len, 2 * enc_hidden_size).
            mask: Tensor of shape (batch, output_len, input_len); non-zero /
                True entries mark score positions that must receive no
                attention weight. Bool and byte masks are both accepted.

        Returns:
            ``(output, attn)``: the attention-adjusted decoder states with
            shape (batch, output_len, dec_hidden_size) and the attention
            weights with shape (batch, output_len, input_len).
        """
        batch_size = output.size(0)
        output_len = output.size(1)
        input_len = context.size(1)

        # Project encoder states into the decoder space before scoring.
        context_in = self.linear_in(context.view(batch_size * input_len, -1)).view(
            batch_size, input_len, -1)
        attn = torch.bmm(output, context_in.transpose(1, 2))

        # masked_fill is out-of-place, so its result must be kept. A large
        # finite negative (rather than -inf) keeps the softmax finite even for
        # rows that are masked everywhere.
        attn = attn.masked_fill(mask.bool(), -1e6)
        attn = F.softmax(attn, dim=2)

        context = torch.bmm(attn, context)
        output = torch.cat((context, output), dim=2)
        output = output.view(batch_size * output_len, -1)
        output = torch.tanh(self.linear_out(output))
        output = output.view(batch_size, output_len, -1)
        return output, attn
class Decoder(nn.Module):
    """GRU decoder with attention over the encoder outputs."""

    def __init__(self, vocab_size, embed_size, enc_hidden_size, dec_hidden_size, dropout=0.2):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(enc_hidden_size, dec_hidden_size)
        # Size the GRU from the constructor argument, not from a module-level
        # variable that merely happens to share the value.
        self.rnn = nn.GRU(embed_size, dec_hidden_size, batch_first=True)
        self.out = nn.Linear(dec_hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_mask(self, x_len, y_len):
        """Build the attention mask for a batch.

        Args:
            x_len: Integer tensor (batch,) of valid lengths along the first
                mask axis (``forward`` passes the decoder lengths here).
            y_len: Integer tensor (batch,) of valid lengths along the second
                mask axis (``forward`` passes the encoder lengths here).

        Returns:
            Bool tensor of shape (batch, max(x_len), max(y_len)) that is True
            wherever either position lies in padding, i.e. wherever attention
            scores must be suppressed.
        """
        device = x_len.device
        max_x_len = int(x_len.max())
        max_y_len = int(y_len.max())
        x_mask = torch.arange(max_x_len, device=device)[None, :] < x_len[:, None]
        y_mask = torch.arange(max_y_len, device=device)[None, :] < y_len[:, None]
        # Negate the combined validity grid. Without the parentheses `~`
        # would bind to x_mask alone and only padded rows would be masked.
        mask = ~(x_mask[:, :, None] & y_mask[:, None, :])
        return mask

    def forward(self, encoder_out, x_lengths, y, y_lengths, hid):
        """Decode target inputs ``y`` while attending to ``encoder_out``.

        Args:
            encoder_out: Encoder states, (batch, src_len, 2 * enc_hidden_size).
            x_lengths: Valid source lengths, (batch,).
            y: Decoder input ids, (batch, tgt_len).
            y_lengths: Valid decoder input lengths, (batch,).
            hid: Initial GRU state, (1, batch, dec_hidden_size).

        Returns:
            ``(output, hid, attn)``: log-softmax scores over the vocabulary
            per step, the final GRU state, and the attention weights, all in
            the original batch order.
        """
        # Packing needs the batch ordered from longest to shortest; the
        # initial hidden state has to follow the same permutation.
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()]

        y_sorted = self.dropout(self.embed(y_sorted))

        packed_seq = nn.utils.rnn.pack_padded_sequence(
            y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

        # Undo the length sort.
        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()

        # Decoder lengths index the rows of the score matrix, encoder lengths
        # its columns.
        mask = self.create_mask(y_lengths, x_lengths)

        output, attn = self.attention(output_seq, encoder_out, mask)
        output = F.log_softmax(self.out(output), -1)

        return output, hid, attn
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper for the attention-based model."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        """Score the target inputs ``y`` given the source batch ``x``.

        Returns:
            ``(output, attn)`` as produced by the decoder (its hidden state
            is dropped).
        """
        encoder_out, hid = self.encoder(x, x_lengths)
        output, _, attn = self.decoder(encoder_out=encoder_out,
                                       x_lengths=x_lengths,
                                       y=y,
                                       y_lengths=y_lengths,
                                       hid=hid)
        return output, attn

    def translate(self, x, x_lengths, y, max_length=100):
        """Greedily decode ``max_length`` tokens.

        Args:
            x: Source id tensor (batch, src_len).
            x_lengths: Valid source lengths (batch,).
            y: First decoder input of shape (batch, 1).
            max_length: Number of decoding steps.

        Returns:
            ``(preds, attns)``: predicted ids of shape (batch, max_length) and
            the per-step attention weights concatenated along dimension 1.
        """
        encoder_out, hid = self.encoder(x, x_lengths)
        batch_size = x.shape[0]
        preds, attns = [], []
        for _step in range(max_length):
            # Each step feeds exactly one token per sentence.
            output, hid, attn = self.decoder(encoder_out=encoder_out,
                                             x_lengths=x_lengths,
                                             y=y,
                                             y_lengths=torch.ones(batch_size).long().to(y.device),
                                             hid=hid)
            _, best = output.max(2)
            y = best.view(batch_size, 1)
            preds.append(y)
            attns.append(attn)
        return torch.cat(preds, 1), torch.cat(attns, 1)
dropout = 0.2
# Embedding size and hidden size share one value for this model.
embed_size = hidden_size = 100
encoder = Encoder(vocab_size=en_total_words,
                  embed_size=embed_size,
                  enc_hidden_size=hidden_size,
                  dec_hidden_size=hidden_size,
                  dropout=dropout)
decoder = Decoder(vocab_size=cn_total_words,
                  embed_size=embed_size,
                  enc_hidden_size=hidden_size,
                  dec_hidden_size=hidden_size,
                  dropout=dropout)
# Rebinding `model`, `loss_fn` and `optimizer` matters: train(), evaluate()
# and translate_dev() read them as module globals, so from here on they
# operate on the attention model.
model = Seq2Seq(encoder, decoder)
model = model.to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())
train(model, train_data, num_epochs=30)
# Show source, reference and prediction for twenty dev sentences, separated
# by blank lines.
for i in range(100, 120):
    translate_dev(i)
    print()