视频地址:https://www.bilibili.com/video/BV1vz4y1R7Mm?p=7
先去 https://github.com/ZeweiChu/PyTorch-Course/tree/master/notebooks 下载数据集(nmt文件夹),并把 nmt 文件夹放在当前工作目录下——下面的代码会读取 nmt/en-cn/train.txt 和 nmt/en-cn/dev.txt。
import os
import sys
import math
from collections import Counter
import numpy as np
import random
import torch
import torch.nn as nn
import nltk
# Log the library versions this notebook was run with, then pick the
# compute device: CUDA when it is available, otherwise the CPU.
for _name, _module in (("torch", torch), ("nltk", nltk)):
    print(_name, _module.__version__)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
torch 1.2.0+cu92
nltk 3.5
def load_data(in_file):
    """Read a tab-separated parallel corpus and tokenize both sides.

    Each non-blank line holds an English sentence, a tab, and its Chinese
    translation. English is lower-cased and tokenized with
    ``nltk.word_tokenize``; Chinese is split into single characters. Every
    token list is wrapped with "BOS" / "EOS" markers.

    Args:
        in_file: path to the UTF-8 encoded corpus file.

    Returns:
        (en, cn): two lists of token lists, aligned line by line.
    """
    en = []
    cn = []
    with open(in_file, 'r', encoding='utf8') as f:
        for raw_line in f:
            line = raw_line.strip()
            # Blank lines have no translation column and would otherwise
            # fail with IndexError on fields[1]; skip them.
            if not line:
                continue
            fields = line.split('\t')
            en.append(["BOS"] + nltk.word_tokenize(fields[0].lower()) + ["EOS"])
            cn.append(["BOS"] + list(fields[1]) + ["EOS"])
    return en, cn
# Tokenized training and development corpora.
train_file = "nmt/en-cn/train.txt"
dev_file = "nmt/en-cn/dev.txt"
(train_en, train_cn), (dev_en, dev_cn) = load_data(train_file), load_data(dev_file)
train_en[:3]
[['BOS', 'anyone', 'can', 'do', 'that', '.', 'EOS'],
['BOS', 'how', 'about', 'another', 'piece', 'of', 'cake', '?', 'EOS'],
['BOS', 'she', 'married', 'him', '.', 'EOS']]
# Reserved vocabulary indices; real tokens are numbered from 2 upward.
UNK_IDX, PAD_IDX = 0, 1


def build_dict(sentences, max_words=50000):
    """Build a token -> index vocabulary from tokenized sentences.

    The ``max_words`` most frequent tokens receive indices 2, 3, ... in order
    of decreasing frequency; "UNK" and "PAD" map to UNK_IDX and PAD_IDX.

    Returns:
        (word_dict, total_words): the vocabulary and its size, i.e. the
        number of kept tokens plus the two reserved entries.
    """
    counts = Counter(token for sentence in sentences for token in sentence)
    kept = counts.most_common(max_words)
    word_dict = {}
    for position, (token, _freq) in enumerate(kept, start=2):
        word_dict[token] = position
    word_dict["UNK"] = UNK_IDX
    word_dict["PAD"] = PAD_IDX
    return word_dict, len(kept) + 2
# Vocabularies for both languages, plus index -> token tables for decoding.
en_dict, en_total_words = build_dict(train_en)
cn_dict, cn_total_words = build_dict(train_cn)
inv_en_dict = dict(zip(en_dict.values(), en_dict.keys()))
inv_cn_dict = dict(zip(cn_dict.values(), cn_dict.keys()))
def encode(en_sentences, cn_sentences, en_dict, cn_dict, sort_by_len=True):
    """Map tokenized sentence pairs to lists of vocabulary indices.

    Tokens missing from a vocabulary map to UNK_IDX. When ``sort_by_len`` is
    true, pairs are reordered by English sentence length (shortest first, stable
    for ties). The Chinese side is permuted with the same order so pairs stay
    aligned.

    Returns:
        (out_en_sentences, out_cn_sentences): lists of index lists.
    """
    out_en_sentences = [[en_dict.get(w, UNK_IDX) for w in sent] for sent in en_sentences]
    out_cn_sentences = [[cn_dict.get(w, UNK_IDX) for w in sent] for sent in cn_sentences]

    if sort_by_len:
        # Compute the permutation once and apply it to both sides.
        sorted_index = sorted(range(len(out_en_sentences)),
                              key=lambda i: len(out_en_sentences[i]))
        out_en_sentences = [out_en_sentences[i] for i in sorted_index]
        out_cn_sentences = [out_cn_sentences[i] for i in sorted_index]
    return out_en_sentences, out_cn_sentences
# Encode train/dev sets, then decode one training pair back to tokens to
# check that the vocabularies round-trip.
train_en, train_cn = encode(train_en, train_cn, en_dict, cn_dict)
dev_en, dev_cn = encode(dev_en, dev_cn, en_dict, cn_dict)
k = 10000
print(" ".join(map(inv_en_dict.__getitem__, train_en[k])))
print(" ".join(map(inv_cn_dict.__getitem__, train_cn[k])))
BOS for what purpose did he come here ? EOS
BOS 他 来 这 里 的 目 的 是 什 么 ? EOS
def get_minibatches(n, minibatch_size, shuffle):
    """Split the indices 0..n-1 into consecutive chunks of ``minibatch_size``.

    Each chunk is a numpy index array and only the last one may be shorter.
    With ``shuffle`` the order of the chunks is randomized via ``np.random``.
    Indices inside each chunk stay contiguous, e.g. [[3, 4, 5], [0, 1, 2], [6, 7, 8]].
    """
    starts = np.arange(0, n, minibatch_size)
    if shuffle:
        np.random.shuffle(starts)
    return [np.arange(start, min(start + minibatch_size, n)) for start in starts]
def prepare_data(seqs, pad_value=0):
    """Pad variable-length index sequences into one rectangular matrix.

    Args:
        seqs: list of index sequences.
        pad_value: value written into the padded tail of each row. Defaults
            to 0 for backward compatibility (note that 0 is also UNK_IDX);
            pass PAD_IDX to make padding explicit.

    Returns:
        (x, x_lengths): ``x`` is an int32 array of shape
        (len(seqs), longest_length) and ``x_lengths`` holds each sequence's
        original length, so callers can mask out the padded positions.
    """
    lengths = [len(seq) for seq in seqs]
    n_samples = len(seqs)
    # np.max raises on an empty list; an empty batch is simply width 0.
    max_len = max(lengths) if lengths else 0
    x = np.full((n_samples, max_len), pad_value, dtype=np.int32)
    x_lengths = np.array(lengths, dtype=np.int32)
    for idx, seq in enumerate(seqs):
        x[idx, :lengths[idx]] = seq
    return x, x_lengths
def gen_examples(en_sentences, cn_sentences, batch_size, shuffle=False):
    """Group aligned sentence pairs into padded mini-batches.

    Returns a list with one (mb_x, mb_x_len, mb_y, mb_y_len) tuple per
    mini-batch: the padded source matrix and lengths, then the padded
    target matrix and lengths, as produced by ``prepare_data``.
    """
    examples = []
    for batch_idx in get_minibatches(len(en_sentences), batch_size, shuffle):
        src_batch = [en_sentences[i] for i in batch_idx]
        tgt_batch = [cn_sentences[i] for i in batch_idx]
        examples.append(prepare_data(src_batch) + prepare_data(tgt_batch))
    return examples
# Pre-build padded mini-batches; only the training batch order is shuffled.
batch_size = 64
train_data = gen_examples(train_en, train_cn, batch_size=batch_size, shuffle=True)
dev_data = gen_examples(dev_en, dev_cn, batch_size=batch_size, shuffle=False)
class SimpleEncoder(nn.Module):
    """Single-layer GRU encoder over embedded source tokens.

    ``forward(x, lengths)`` returns the per-step GRU outputs and the
    last-layer final hidden state, both in the caller's batch order.
    """

    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        # pack_padded_sequence is called here with the batch sorted by
        # decreasing length; sort first and undo the permutation afterwards.
        lens_desc, order = lengths.sort(0, descending=True)
        restore = order.sort(0, descending=False)[1].long()

        emb = self.dropout(self.embed(x[order]))
        packed = nn.utils.rnn.pack_padded_sequence(
            emb, lens_desc.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed)
        out = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)[0]

        out = out[restore].contiguous()
        hid = hid[:, restore].contiguous()
        # Keep only the last layer while retaining the leading layer axis.
        return out, hid[[-1]]
class SimpleDecoder(nn.Module):
    """GRU decoder that conditions every step on a fixed context vector.

    At each step the GRU input is [embedding(y_t); context], and the output
    layer sees [gru_output_t; context]. In SimpleSeq2Seq the context is the
    encoder's final hidden state.
    """

    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(SimpleDecoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(2 * hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, y, y_lengths, hid, context=None):
        """Decode a batch of (teacher-forced) target prefixes.

        Args:
            y: target token indices, shape (batch, y_len).
            y_lengths: valid length of each row of ``y``.
            hid: initial GRU state, shape (1, batch, hidden_size).
            context: vector concatenated to every step, shape
                (1, batch, hidden_size). Defaults to ``hid`` (the original
                behaviour). Pass it explicitly when decoding step by step so
                the context stays fixed while ``hid`` advances.

        Returns:
            (log_probs, hid2): log-probabilities of shape
            (batch, y_len, vocab_size) and the final GRU state of shape
            (1, batch, hidden_size), both in the caller's batch order.
        """
        if context is None:
            context = hid

        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        original_idx = original_idx.long()

        embedded = self.dropout(self.embed(y[sorted_idx]))
        # While packed, both the initial state and the context must follow
        # the sorted batch order.
        sorted_hid = hid[:, sorted_idx]
        sorted_context = context[:, sorted_idx]
        embedded = torch.cat(
            [embedded, sorted_context.squeeze(0).unsqueeze(1).expand_as(embedded)], 2)

        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid2 = self.rnn(packed, sorted_hid)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        out = out[original_idx].contiguous()
        hid2 = hid2[:, original_idx].contiguous()
        # ``out`` is back in the caller's order, so attach the context in that
        # same (unsorted) order. Using the sorted copy here would pair each
        # row with another example's context.
        out = torch.cat([out, context.squeeze(0).unsqueeze(1).expand_as(out)], 2)
        out = self.fc(out)
        # Normalise over the vocabulary (last) dimension.
        out = nn.functional.log_softmax(out, -1)
        return out, hid2


class SimpleSeq2Seq(nn.Module):
    """Plain encoder-decoder without attention."""

    def __init__(self, encoder, decoder):
        super(SimpleSeq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        """Teacher-forced pass; returns (log_probs, None) since there is no attention."""
        _, hid = self.encoder(x, x_lengths)
        output, _ = self.decoder(y, y_lengths, hid)
        return output, None

    def translate(self, x, x_lengths, y, max_length=10):
        """Greedily decode ``max_length`` tokens, starting from tokens ``y``.

        The encoder's final state is the fixed context for every step, while
        the decoder's recurrent state is carried from step to step, matching
        the teacher-forced ``forward`` pass.

        Returns:
            (preds, None): predicted indices of shape (batch, max_length).
        """
        _, context = self.encoder(x, x_lengths)
        hid = context
        batch_size = x.shape[0]
        step_lengths = torch.ones(batch_size, dtype=torch.long, device=y.device)
        preds = []
        for _ in range(max_length):
            output, hid = self.decoder(y=y, y_lengths=step_lengths, hid=hid, context=context)
            y = output.max(2)[1].view(batch_size, 1)
            preds.append(y)
        return torch.cat(preds, 1), None
class LanguageModelCriterion(nn.Module):
    """Masked negative log-likelihood, averaged over the unmasked positions.

    ``forward(input, target, mask)`` expects log-probabilities ``input`` of
    shape (batch, seq_len, vocab_size). ``target`` holds gold indices and
    ``mask`` holds per-position weights (1 for real tokens, 0 for padding).
    Both are flattened to one entry per (batch, position).
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target, mask):
        vocab_size = input.size(2)
        log_probs = input.contiguous().view(-1, vocab_size)
        gold = target.contiguous().view(-1, 1)
        weights = mask.contiguous().view(-1, 1)
        # Pick each gold token's log-probability and zero out padded positions.
        picked = log_probs.gather(1, gold)
        total_nll = torch.sum(-picked * weights)
        return total_nll / torch.sum(weights)
# Plain (attention-free) model: hyper-parameters, model, loss and optimizer.
dropout = 0.2
hidden_size = 100
encoder = SimpleEncoder(en_total_words, hidden_size, dropout)
decoder = SimpleDecoder(cn_total_words, hidden_size, dropout)
model = SimpleSeq2Seq(encoder, decoder).to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())
def _batch_loss(model, batch):
    """Run ``model`` on one numpy mini-batch and return (loss, num_words).

    Teacher forcing: the decoder input is the target without its last token
    and the prediction target is the target shifted left by one, so every
    target length shrinks by one. Padded target positions are masked out.
    """
    x, x_len, y, y_len = batch
    x = torch.from_numpy(x).to(device).long()
    x_len = torch.from_numpy(x_len).to(device).long()
    dec_input = torch.from_numpy(y[:, :-1]).to(device).long()
    dec_target = torch.from_numpy(y[:, 1:]).to(device).long()
    y_len = torch.from_numpy(y_len - 1).to(device).long()

    pred, _ = model(x, x_len, dec_input, y_len)
    # 1.0 at real target positions, 0.0 at padding.
    out_mask = torch.arange(y_len.max().item(), device=device)[None, :] < y_len[:, None]
    out_mask = out_mask.float()
    loss = loss_fn(pred, dec_target, out_mask)
    return loss, torch.sum(y_len).item()


def evaluate(model, data):
    """Print and return the average per-word loss of ``model`` over ``data``.

    Returns NaN when ``data`` contains no target words.
    """
    model.eval()
    total_num_words = 0
    total_loss = 0
    with torch.no_grad():
        for batch in data:
            loss, num_words = _batch_loss(model, batch)
            total_loss += loss.item() * num_words
            total_num_words += num_words
    avg_loss = total_loss / total_num_words if total_num_words else float("nan")
    print(f"Evaluation loss {avg_loss}")
    return avg_loss


def train(model, data, num_epochs=30):
    """Train ``model`` on ``data`` and evaluate on ``dev_data`` after each epoch."""
    for epoch in range(num_epochs):
        model.train()
        total_num_words = 0
        total_loss = 0
        for it, batch in enumerate(data):
            loss, num_words = _batch_loss(model, batch)
            total_loss += loss.item() * num_words
            total_num_words += num_words
            # Update the model; gradient clipping keeps RNN training stable.
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()
            if it % 20 == 0:
                print(f"Epoch {epoch} iteration {it} loss {loss.item()}")
        avg_loss = total_loss / total_num_words if total_num_words else float("nan")
        print(f"Epoch {epoch} Tranning loss {avg_loss}")
        evaluate(model, dev_data)


train(model, train_data, num_epochs=2)
Epoch 0 iteration 0 loss 4.1712327003479
Epoch 0 iteration 20 loss 4.116885185241699
Epoch 0 iteration 40 loss 3.6872975826263428
Epoch 0 iteration 60 loss 3.561073064804077
Epoch 0 iteration 80 loss 3.853266954421997
Epoch 0 iteration 100 loss 3.8749101161956787
Epoch 0 iteration 120 loss 3.787318468093872
Epoch 0 iteration 140 loss 3.8778975009918213
Epoch 0 iteration 160 loss 3.499333143234253
Epoch 0 iteration 180 loss 3.4058444499969482
Epoch 0 iteration 200 loss 3.708536386489868
Epoch 0 iteration 220 loss 3.7492809295654297
Epoch 0 Tranning loss 4.090564886316073
Evaluation loss 3.973093799748185
Epoch 1 iteration 0 loss 3.813424825668335
Epoch 1 iteration 20 loss 3.7274014949798584
Epoch 1 iteration 40 loss 3.332293748855591
Epoch 1 iteration 60 loss 3.245396375656128
Epoch 1 iteration 80 loss 3.549337863922119
Epoch 1 iteration 100 loss 3.5778181552886963
Epoch 1 iteration 120 loss 3.4655778408050537
Epoch 1 iteration 140 loss 3.6115944385528564
Epoch 1 iteration 160 loss 3.233412027359009
Epoch 1 iteration 180 loss 3.12416672706604
Epoch 1 iteration 200 loss 3.426990032196045
Epoch 1 iteration 220 loss 3.5066943168640137
Epoch 1 Tranning loss 3.796283439917883
Evaluation loss 3.787926131268771
# Sanity-check the plain model on the first two dev sentences: print the
# first source and its reference, then greedy translations of both.
print(" ".join([inv_en_dict[i] for i in dev_en[0]]))
print("".join([inv_cn_dict[i] for i in dev_cn[0]]))
# Pad the sources to a common length: np.array over index lists of different
# lengths does not give a 2-D integer matrix. The true lengths are passed
# separately, and the encoder packs by them, so padding is ignored.
src, src_len = prepare_data([dev_en[0], dev_en[1]])
x = torch.from_numpy(src).long().to(device)
x_len = torch.from_numpy(src_len).long().to(device)
bos = torch.Tensor([[cn_dict["BOS"]], [cn_dict["BOS"]]]).long().to(device)
translation, attn = model.translate(x, x_len, bos)
translation = [[inv_cn_dict[i] for i in sentence] for sentence in translation.data.cpu().numpy()]
print(translation)
BOS look around . EOS
BOS四处看看。EOS
tensor([[2],
[2]])
tensor([5, 5])
[['她', '的', '人', '都', '喜', '欢', '迎', '是', '我', '們'], ['汤', '姆', '在', '这', '个', '人', '都', '是', '汤', '姆']]
class Attention(nn.Module):
    """Luong-style ("general") attention over bidirectional encoder outputs.

    Scores each decoder output against every encoder output and turns the
    scores into weights with a masked softmax. The weighted context is then
    mixed with the decoder output through ``linear_out`` followed by tanh.
    """

    def __init__(self, encoder_hidden_size, decoder_hidden_size):
        super(Attention, self).__init__()
        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size
        # Encoder outputs are 2 * encoder_hidden_size wide (two directions).
        self.linear_in = nn.Linear(encoder_hidden_size * 2, decoder_hidden_size, bias=False)
        self.linear_out = nn.Linear(encoder_hidden_size * 2 + decoder_hidden_size, decoder_hidden_size)

    def forward(self, output, context, mask):
        """Attend from decoder outputs to encoder outputs.

        Args:
            output: decoder outputs, (batch, output_len, decoder_hidden_size).
            context: encoder outputs, (batch, context_len, 2 * encoder_hidden_size).
            mask: boolean, (batch, output_len, context_len); True marks
                (output, context) pairs that must not be attended to.

        Returns:
            (output, attn): attentional outputs of shape
            (batch, output_len, decoder_hidden_size) and the attention
            weights of shape (batch, output_len, context_len).
        """
        batch_size = output.size(0)
        output_len = output.size(1)
        input_len = context.size(1)

        # Project encoder outputs into decoder space: (batch, context_len, dec_hidden).
        context_in = self.linear_in(context.view(batch_size * input_len, -1)).view(batch_size, input_len, -1)
        # Raw scores: (batch, output_len, context_len).
        attn = torch.bmm(output, context_in.transpose(1, 2))
        # masked_fill is out-of-place, so its result must be assigned back.
        # A very negative score gives masked positions ~zero softmax weight.
        attn = attn.masked_fill(mask, -1e6)
        attn = nn.functional.softmax(attn, dim=2)

        # Weighted sum of encoder outputs: (batch, output_len, 2 * enc_hidden).
        context = torch.bmm(attn, context)
        output = torch.cat((context, output), dim=2)
        output = output.view(batch_size * output_len, -1)
        output = torch.tanh(self.linear_out(output))
        output = output.view(batch_size, output_len, -1)
        return output, attn
class Encoder(nn.Module):
    """Bidirectional GRU encoder.

    Returns the per-step outputs (both directions concatenated) and an
    initial decoder state. That state is built from the final hidden states
    of the two directions via ``fc`` and tanh.
    """

    def __init__(self, vocab_size, embed_size, encoder_hidden_size, decoder_hidden_size, dropout=0.2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, encoder_hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(encoder_hidden_size * 2, decoder_hidden_size)

    def forward(self, x, lengths):
        # Sort by decreasing length for packing, then restore caller order.
        lens_desc, order = lengths.sort(0, descending=True)
        restore = order.sort(0, descending=False)[1].long()

        emb = self.dropout(self.embed(x[order]))
        packed = nn.utils.rnn.pack_padded_sequence(
            emb, lens_desc.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed)
        out = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)[0]

        out = out[restore].contiguous()
        hid = hid[:, restore].contiguous()
        # The last two entries of hid are the two directions of the single
        # GRU layer; join them and map to the decoder's hidden size.
        final = torch.cat((hid[-2], hid[-1]), dim=1)
        return out, torch.tanh(self.fc(final)).unsqueeze(0)
class Decoder(nn.Module):
    """GRU decoder with Luong attention over the encoder outputs."""

    def __init__(self, vocab_size, embed_size, encoder_hidden_size, decoder_hidden_size, dropout=0.2):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(encoder_hidden_size, decoder_hidden_size)
        # Size the GRU from the constructor argument, not a module-level
        # ``hidden_size`` that merely happened to have the same value.
        self.rnn = nn.GRU(embed_size, decoder_hidden_size, batch_first=True)
        self.fc = nn.Linear(decoder_hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_mask(self, x_len, y_len):
        """Return a boolean mask of shape (batch, max(x_len), max(y_len)).

        Entry [b, i, j] is True when position i (w.r.t. ``x_len``) or
        position j (w.r.t. ``y_len``) is padding for example b.
        """
        max_x_len = x_len.max()
        max_y_len = y_len.max()
        x_mask = torch.arange(max_x_len, device=x_len.device)[None, :] < x_len[:, None]
        y_mask = torch.arange(max_y_len, device=x_len.device)[None, :] < y_len[:, None]
        mask = ~(x_mask[:, :, None] & y_mask[:, None, :])
        return mask

    def forward(self, context, context_lengths, y, y_lengths, hid):
        """Decode target prefixes ``y`` while attending to ``context``.

        Args:
            context: encoder outputs, (batch, context_len, 2 * encoder_hidden_size).
            context_lengths: valid length of each encoder sequence.
            y: target token indices, (batch, y_len).
            y_lengths: valid length of each row of ``y``.
            hid: initial GRU state, (1, batch, decoder_hidden_size).

        Returns:
            (log_probs, hid2, attention): log-probabilities of shape
            (batch, y_len, vocab_size), the GRU's final state
            (1, batch, decoder_hidden_size), and the attention weights, all
            in the caller's batch order.
        """
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        # The targets and the initial state must both follow the sorted order.
        y_sorted = self.dropout(self.embed(y[sorted_idx]))  # batch, y_len, embed_size
        sorted_hid = hid[:, sorted_idx]
        packed_seq = nn.utils.rnn.pack_padded_sequence(
            y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
        out, hid2 = self.rnn(packed_seq, sorted_hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        _, original_idx = sorted_idx.sort(0, descending=False)
        output_seq = unpacked[original_idx.long()].contiguous()  # batch, y_len, dec_hidden
        # Return the GRU's final state (in caller order) so greedy decoding
        # can continue from it instead of restarting from ``hid``.
        hid2 = hid2[:, original_idx.long()].contiguous()
        mask = self.create_mask(y_lengths, context_lengths)
        output, attention = self.attention(output_seq, context, mask)
        # Normalise over the vocabulary (last) dimension.
        output = nn.functional.log_softmax(self.fc(output), -1)
        return output, hid2, attention
class Seq2Seq(nn.Module):
    """Attention-based encoder-decoder."""

    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        """Teacher-forced pass; returns (log_probs, attention_weights)."""
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid, attn = self.decoder(
            context=encoder_out,
            context_lengths=x_lengths,
            y=y,
            y_lengths=y_lengths,
            hid=hid)
        return output, attn

    def translate(self, x, x_lengths, y, max_length=100):
        """Greedy decoding for ``max_length`` steps, starting from tokens ``y``.

        Returns:
            (preds, attns): predicted indices of shape (batch, max_length)
            and the per-step attention weights concatenated along dim 1.
        """
        encoder_out, hid = self.encoder(x, x_lengths)
        batch_size = x.shape[0]
        # One token per step. Build the lengths on the inputs' device rather
        # than depending on a module-level ``device`` global.
        step_lengths = torch.ones(batch_size, dtype=torch.long, device=y.device)
        preds = []
        attns = []
        for _ in range(max_length):
            output, hid, attn = self.decoder(
                context=encoder_out,
                context_lengths=x_lengths,
                y=y,
                y_lengths=step_lengths,
                hid=hid)
            y = output.max(2)[1].view(batch_size, 1)
            preds.append(y)
            attns.append(attn)
        return torch.cat(preds, 1), torch.cat(attns, 1)
"""训练"""
# Attention model: hyper-parameters, model, loss, optimizer, then training.
dropout = 0.2
embed_size = hidden_size = 100
_model_kwargs = dict(embed_size=embed_size,
                     encoder_hidden_size=hidden_size,
                     decoder_hidden_size=hidden_size,
                     dropout=dropout)
encoder = Encoder(vocab_size=en_total_words, **_model_kwargs)
decoder = Decoder(vocab_size=cn_total_words, **_model_kwargs)
model = Seq2Seq(encoder, decoder).to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())
train(model, train_data, num_epochs=3)
Epoch 0 iteration 0 loss 8.088303565979004
Epoch 0 iteration 20 loss 6.1903395652771
...
Epoch 0 iteration 220 loss 4.88083028793335
Epoch 0 Tranning loss 5.531404307857315
Evaluation loss 5.054636112028756
Epoch 1 iteration 0 loss 5.031433582305908
Epoch 1 iteration 20 loss 5.012241363525391
...
Epoch 1 iteration 220 loss 4.389835357666016
Epoch 1 Tranning loss 4.855371610606542
Evaluation loss 4.594973376533105
Epoch 2 iteration 0 loss 4.569258213043213
Epoch 2 iteration 20 loss 4.4955363273620605
...
Epoch 2 iteration 220 loss 4.008622646331787
Epoch 2 Tranning loss 4.423912357264402
Evaluation loss 4.228922175465478
# Translate one dev sentence with the attention model and show it next to
# the source sentence and the reference translation.
k = 120
print(" ".join(map(inv_en_dict.__getitem__, dev_en[k])))
print("".join(map(inv_cn_dict.__getitem__, dev_cn[k])))
x = torch.tensor([dev_en[k]], dtype=torch.long, device=device)
x_len = torch.tensor([len(dev_en[k])], dtype=torch.long, device=device)
bos = torch.tensor([[cn_dict["BOS"]]], dtype=torch.long, device=device)
translation, attn = model.translate(x, x_len, bos)
translation = [[inv_cn_dict[idx] for idx in row] for row in translation.cpu().numpy()]
print(translation)
BOS i like your room . EOS
BOS我喜欢你的房间。EOS
[['我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我']]