This section implements translating English sentences into Chinese sentences. Beyond machine translation, seq2seq can actually be used in many other settings, such as automatic dialogue bots, automatic document summarization, and automatic image caption generation. The overall workflow:
① Load the data (pairs of English and Chinese sentences)
② Build the vocabularies and encode words as indices
③ Construct minibatches: sort sentences by length so each batch holds sentences of similar length, then pad to the batch maximum (see the sketch after load_data below)
④ Define the model
⑤ Define the loss and the optimizer
⑥ Train and evaluate
import nltk

def load_data(in_file):
    # each line holds an English sentence and its Chinese translation,
    # separated by a tab character
    cn = []
    en = []
    with open(in_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip().split('\t')
            # English: lower-case, tokenize with NLTK, wrap with BOS/EOS markers
            en.append(['BOS'] + nltk.word_tokenize(line[0].lower()) + ['EOS'])
            # Chinese: split into individual characters, wrap with BOS/EOS markers
            cn.append(['BOS'] + [ch for ch in line[1]] + ['EOS'])
    return en, cn
train_en,train_cn=load_data("train.txt")
dev_en,dev_cn=load_data("dev.txt")
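Steps ② and ③ are not shown in the original code. Below is a minimal sketch of how they could look; the vocabulary cap max_words, the reserved UNK/PAD indices, and the helper names build_dict, encode, and prepare_minibatch are assumptions for illustration, not the author's exact code.

import numpy as np
from collections import Counter

UNK_IDX = 0  # assumed index for unknown words
PAD_IDX = 1  # assumed index for padding

def build_dict(sentences, max_words=50000):
    # count word frequencies and keep the max_words most frequent entries
    word_count = Counter()
    for sentence in sentences:
        for word in sentence:
            word_count[word] += 1
    most_common = word_count.most_common(max_words)
    word_dict = {w: i + 2 for i, (w, _) in enumerate(most_common)}  # 0 and 1 are reserved
    word_dict['UNK'] = UNK_IDX
    word_dict['PAD'] = PAD_IDX
    return word_dict, len(word_dict)

def encode(en_sentences, cn_sentences, en_dict, cn_dict, sort_by_len=True):
    # map words to indices; unknown words fall back to UNK
    out_en = [[en_dict.get(w, UNK_IDX) for w in s] for s in en_sentences]
    out_cn = [[cn_dict.get(w, UNK_IDX) for w in s] for s in cn_sentences]
    if sort_by_len:
        # sort by English length so each batch holds similar-length sentences
        order = sorted(range(len(out_en)), key=lambda i: len(out_en[i]))
        out_en = [out_en[i] for i in order]
        out_cn = [out_cn[i] for i in order]
    return out_en, out_cn

def prepare_minibatch(seqs, pad=PAD_IDX):
    # pad a list of index sequences to the length of the longest one
    lengths = np.array([len(s) for s in seqs])
    batch = np.full((len(seqs), lengths.max()), pad, dtype='int64')
    for i, s in enumerate(seqs):
        batch[i, :len(s)] = s
    return batch, lengths

en_dict, en_total_words = build_dict(train_en)
cn_dict, cn_total_words = build_dict(train_cn)
train_en_idx, train_cn_idx = encode(train_en, train_cn, en_dict, cn_dict)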
import torch
from torch import nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.5):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, x_len):
        # pack_padded_sequence expects the batch sorted by decreasing length
        sorted_len, sorted_idx = x_len.sort(0, descending=True)
        x_sorted = x[sorted_idx.long()]
        embedded = self.dropout(self.embed(x_sorted))
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, sorted_len.long().cpu().numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed_embedded)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        # restore the original batch order
        _, orig_idx = sorted_idx.sort(0, descending=False)
        out = out[orig_idx.long()].contiguous()
        hid = hid[:, orig_idx.long()].contiguous()
        return out, hid[[-1]]  # last layer's hidden state, shape (1, batch, hidden)
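Note the double brackets in hid[[-1]]: plain hid[-1] would drop the layer dimension, while hid[[-1]] keeps it, which is the shape the decoder's GRU expects as its initial hidden state. A quick illustration:

h = torch.zeros(2, 4, 8)    # (num_layers, batch, hidden_size)
print(h[-1].shape)          # torch.Size([4, 8])    - layer dimension dropped
print(h[[-1]].shape)        # torch.Size([1, 4, 8]) - layer dimension kept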
class Decoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, y, y_len, hid):
        # sort the target batch by decreasing length and reorder the encoder
        # hidden state along the batch dimension to match
        sorted_len, sorted_idx = y_len.sort(0, descending=True)
        y_sorted = y[sorted_idx.long()]
        hid = hid[:, sorted_idx.long()]
        y_embed = self.dropout(self.embed(y_sorted))
        packed_seq = nn.utils.rnn.pack_padded_sequence(
            y_embed, sorted_len.long().cpu().numpy(), batch_first=True)
        output, hid = self.rnn(packed_seq, hid)  # the encoder state initializes the GRU
        out, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
        # restore the original batch order
        _, orig_idx = sorted_idx.sort(0, descending=False)
        out = out[orig_idx.long()].contiguous()
        hid = hid[:, orig_idx.long()].contiguous()
        out = F.log_softmax(self.linear(out), -1)  # log-probabilities over the target vocabulary
        return out, hid
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid = self.decoder(y=y, y_len=y_lengths, hid=hid)
        return output, None  # None is a placeholder for attention weights
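As a quick sanity check, the three modules can be wired together and run on random tensors; all sizes below are arbitrary and only illustrate the expected shapes.

encoder = Encoder(vocab_size=100, hidden_size=64)
decoder = Decoder(vocab_size=120, hidden_size=64)
model = Seq2Seq(encoder, decoder)

x = torch.randint(0, 100, (4, 10))   # 4 source sentences, padded to length 10
x_len = torch.tensor([10, 8, 6, 5])
y = torch.randint(0, 120, (4, 12))   # 4 target sentences, padded to length 12
y_len = torch.tensor([12, 9, 7, 6])

out, _ = model(x, x_len, y, y_len)
print(out.shape)                     # torch.Size([4, 12, 120]): log-probs over the target vocab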
# masked cross-entropy loss: padding positions are excluded from the loss
class LanguageModelCriterion(nn.Module):
    def __init__(self):
        super(LanguageModelCriterion, self).__init__()

    def forward(self, input, target, mask):
        # input:  batch_size x seq_len x vocab_size (log-probabilities)
        # target: batch_size x seq_len
        # mask:   batch_size x seq_len, 1.0 at real tokens, 0.0 at padding
        input = input.contiguous().view(-1, input.size(2))
        target = target.contiguous().view(-1, 1)
        mask = mask.contiguous().view(-1, 1)
        # pick the log-probability of each target word, zero out padded positions
        output = -input.gather(1, target) * mask
        # average the negative log-likelihood over the real (unmasked) tokens
        output = torch.sum(output) / torch.sum(mask)
        return output
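The training loop below refers to model, loss_fn, optimizer, device, num_epochs, data, and dev_data, none of which are defined in the excerpt. A plausible setup (the hidden size, the Adam optimizer, and the epoch count are assumptions; the vocabulary sizes come from the build_dict sketch above) might be:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_size = 100   # assumed; the original does not state the value

encoder = Encoder(vocab_size=en_total_words, hidden_size=hidden_size)
decoder = Decoder(vocab_size=cn_total_words, hidden_size=hidden_size)
model = Seq2Seq(encoder, decoder).to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())
num_epochs = 20     # assumed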
for epoch in range(num_epochs):
    model.train()
    total_num_words = total_loss = 0.
    for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
        mb_x = torch.from_numpy(mb_x).to(device).long()
        mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
        # teacher forcing: feed y without its last token, predict y shifted left by one
        mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
        mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
        mb_y_len = torch.from_numpy(mb_y_len - 1).to(device).long()  # each target is one token shorter after the shift
        mb_y_len[mb_y_len <= 0] = 1  # clamp non-positive lengths to 1
        mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)
        # mask is 1 at positions within each sentence's real length, 0 at padding
        mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
        mb_out_mask = mb_out_mask.float()
        loss = loss_fn(mb_pred, mb_output, mb_out_mask)
        num_words = torch.sum(mb_y_len).item()
        total_loss += loss.item() * num_words
        total_num_words += num_words
        # update the model
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)  # clip gradients to keep them from exploding
        optimizer.step()
        if it % 100 == 0:
            print("Epoch", epoch, "iteration", it, "loss", loss.item())
    print("Epoch", epoch, "Training loss", total_loss / total_num_words)
    if epoch % 5 == 0:
        evaluate(model, dev_data)
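evaluate is called above but never defined in the excerpt. A minimal sketch, mirroring the training loop without gradient updates, could look like this (in a real script this definition would precede the training loop):

def evaluate(model, data):
    model.eval()
    total_num_words = total_loss = 0.
    with torch.no_grad():
        for it, (mb_x, mb_x_len, mb_y, mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device).long()
            mb_x_len = torch.from_numpy(mb_x_len).to(device).long()
            mb_input = torch.from_numpy(mb_y[:, :-1]).to(device).long()
            mb_output = torch.from_numpy(mb_y[:, 1:]).to(device).long()
            mb_y_len = torch.from_numpy(mb_y_len - 1).to(device).long()
            mb_y_len[mb_y_len <= 0] = 1
            mb_pred, attn = model(mb_x, mb_x_len, mb_input, mb_y_len)
            mb_out_mask = torch.arange(mb_y_len.max().item(), device=device)[None, :] < mb_y_len[:, None]
            mb_out_mask = mb_out_mask.float()
            loss = loss_fn(mb_pred, mb_output, mb_out_mask)
            num_words = torch.sum(mb_y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words
    print("Evaluation loss", total_loss / total_num_words)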