[PyTorch NLP in Practice]: Training a Seq2Seq Model That Outputs Antonyms (in Under 100 Lines of Code)

Overview:

We train a model that outputs antonyms using an RNN encoder-decoder. The goal is to get comfortable with PyTorch; every function that might trip up a beginner is linked to an explanatory blog post, so dig in with confidence.

Everything you need is in the code:

#coding=utf-8
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable


char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']  # all the characters we need: 'S' marks the start of decoding, 'P' is the padding character, 'E' marks the end of the output
num_dic = {n: i for i, n in enumerate(char_arr)}  # for the enumerate function, see Explanation 1
# the resulting dictionary is {'S': 0, 'E': 1, 'P': 2, 'a': 3, ..., 'z': 28}

seq_data = [['man', 'women'], ['black', 'white'], ['king', 'queen'], ['girl', 'boy'], ['up', 'down'], ['high', 'low'], ['good', 'bad']]  # the antonym pairs we train on

# model hyperparameters
n_hidden = 128  # number of hidden units
n_class = len(char_arr)  # total number of characters, so each one-hot encoded character is a 29-dimensional vector
batch_size = len(seq_data)  # number of training pairs
max_len = 5  # maximum word length; shorter words are padded with 'P'


def make_batch(seq_data):
    input_batch, output_batch, target_batch = [], [], []
    for seq in seq_data:
        seq[0] = seq[0] + (max_len - len(seq[0])) * 'P'  # pad the input word with 'P' up to max_len
        seq[1] = seq[1] + (max_len - len(seq[1])) * 'P'  # pad the target word the same way
        input = [num_dic[s] for s in seq[0]]
        output = [num_dic[s] for s in 'S' + seq[1]]  # a list of indices: each number is the position of the corresponding character in num_dic; print it and see for yourself
        target = [num_dic[s] for s in seq[1] + 'E']  # the target is NOT one-hot encoded
        input_batch.append(np.eye(n_class)[input])  # for np.eye, see Explanation 2; this yields a len(input) x n_class one-hot matrix
        output_batch.append(np.eye(n_class)[output])
        target_batch.append(target)

    return Variable(torch.Tensor(input_batch)), Variable(torch.Tensor(output_batch)), Variable(torch.LongTensor(target_batch))
# All three are returned as tensors. Why a LongTensor for the target? Because nn.CrossEntropyLoss expects its second argument (the class indices) to be a LongTensor.

# Build the model
class Seqtoseq(nn.Module):
    def __init__(self):
        super(Seqtoseq, self).__init__()
        self.encoder = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)  # two RNNs: one for encoding, one for decoding
        self.decoder = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)  # if the nn.RNN parameters are unclear, see Explanation 3
        self.fc = nn.Linear(n_hidden, n_class)

    def forward(self, encoder_input, encoder_hidden, decoder_input):  # forward pass used during training
        # encoder_input: [batch_size, max_len, n_class]
        # decoder_input: [batch_size, max_len+1, n_class]
        encoder_input = encoder_input.transpose(0, 1)  # -> [max_len, batch_size, n_class]; if the transpose is unclear, see Explanation 4
        decoder_input = decoder_input.transpose(0, 1)  # -> [max_len+1, batch_size, n_class]

        # encoder_state : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
        _, encoder_state = self.encoder(encoder_input, encoder_hidden)  # the RNN encoder-decoder passes the encoder's final hidden state h_n to the decoder as its initial hidden state
        decoder_output, _ = self.decoder(decoder_input, encoder_state)
        # decoder_output : [max_len+1(=6), batch_size, num_directions(=1) * n_hidden(=128)]
        model = self.fc(decoder_output)
        # model : [max_len+1(=6), batch_size, n_class]
        return model

input_batch, output_batch, target_batch = make_batch(seq_data)  # build the training batch

# Build the model, the cross-entropy loss, and the Adam optimizer with learning rate 0.01
model = Seqtoseq()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Training
for epoch in range(5000):
    hidden = Variable(torch.zeros(1, batch_size, n_hidden))  # initial hidden state of the encoder
    optimizer.zero_grad()  # the gradients must be cleared before every update
    # input_batch : [batch_size, max_len, n_class]
    # output_batch : [batch_size, max_len+1 (because of the leading 'S'), n_class]
    # target_batch : [batch_size, max_len+1], not one-hot
    output = model(input_batch, hidden, output_batch)
    output = output.transpose(0, 1)  # [batch_size, max_len+1(=6), n_class]
    loss = 0
    for i in range(0, len(target_batch)):
        # output[i] : [max_len+1, n_class], target_batch[i] : [max_len+1]
        loss += criterion(output[i], target_batch[i])
    if (epoch + 1) % 1000 == 0:
        print("epoch ==  ", "%04d" % (epoch + 1), "loss ==  ", "{:.6f}".format(loss))
    loss.backward()  # backpropagation
    optimizer.step()

def translate(word):
    input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]])  # at test time the decoder input is just 'S' followed by 'P' placeholders, since the answer is unknown
    hidden = Variable(torch.zeros(1, 1, n_hidden))

    output = model(input_batch, hidden, output_batch)
    # output : [max_len+1(=6), batch_size(=1), n_class]

    predict = output.data.max(2, keepdim=True)[1]  # index of the most likely character at each time step; for the max function, see Explanation 5
    decoded = [char_arr[i] for i in predict]

    end = decoded.index('E')  # keep everything before the first 'E'
    answer = ''.join(decoded[:end])
    return answer.replace('P', '')  # strip any padding

print('test')
print('man ->', translate('man'))
print('mans ->', translate('mans'))
print('king ->', translate('king'))
print('black ->', translate('black'))
print('upp ->', translate('upp'))

Explanation 1:

For the use of the enumerate function, see: https://blog.csdn.net/lwgkzl/article/details/88735271
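
As a quick illustration (not part of the training script): enumerate pairs each character with its position, and the dictionary comprehension inverts those pairs into a character-to-index lookup.

char_arr = [c for c in 'SEPabcdefghijklmnopqrstuvwxyz']
print(list(enumerate(char_arr))[:4])             # [(0, 'S'), (1, 'E'), (2, 'P'), (3, 'a')]
num_dic = {n: i for i, n in enumerate(char_arr)}
print(num_dic['S'], num_dic['a'], num_dic['z'])  # 0 3 28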

Explanation 2:

For the use of np.eye, see: https://blog.csdn.net/chixujohnny/article/details/51011931
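
A small sketch of the trick used in make_batch: np.eye(n_class) is the 29 x 29 identity matrix, and indexing it with a list of character indices picks out one one-hot row per character.

import numpy as np

n_class = 29
indices = [0, 3, 28]                 # 'S', 'a', 'z' under num_dic above
one_hot = np.eye(n_class)[indices]
print(one_hot.shape)                 # (3, 29): one 29-dimensional one-hot row per character
print(one_hot[1][3])                 # 1.0, only the 'a' position is set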

Explanation 3:

For the use of nn.RNN, see: https://blog.csdn.net/lwgkzl/article/details/88717678
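
A self-contained shape check for nn.RNN with the default batch_first=False, using the same sizes as this post: the input is [seq_len, batch_size, input_size], the initial hidden state is [num_layers * num_directions, batch_size, hidden_size], and the module returns the per-step outputs plus the final hidden state h_n.

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=29, hidden_size=128)  # a single unidirectional layer
x = torch.zeros(5, 7, 29)                     # [seq_len=5, batch_size=7, input_size=29]
h0 = torch.zeros(1, 7, 128)                   # [num_layers*num_directions=1, batch_size=7, hidden_size=128]
output, hn = rnn(x, h0)
print(output.shape)                           # torch.Size([5, 7, 128]): the output at every time step
print(hn.shape)                               # torch.Size([1, 7, 128]): the final hidden state h_n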

Explanation 4:

Why do we need the transpose here? make_batch returns tensors whose first dimension is the batch ([batch_size, seq_len, n_class]), but nn.RNN with the default batch_first=False expects [seq_len, batch_size, input_size], so we swap the first two dimensions before feeding the data to the encoder and decoder.
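
A one-line sanity check of the shape change (the sizes are just for illustration):

import torch

x = torch.zeros(7, 5, 29)       # [batch_size, seq_len, n_class], as returned by make_batch
x_t = x.transpose(0, 1)         # swap dimensions 0 and 1
print(x_t.shape)                # torch.Size([5, 7, 29]): [seq_len, batch_size, n_class], what nn.RNN expects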

Explanation 5:

For the use of the max function, see: https://blog.csdn.net/Z_lbj/article/details/79766690
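
A short example of the call used in translate: tensor.max(dim, keepdim=True) returns a (values, indices) pair, and taking [1] keeps the indices, i.e. the predicted character class at each time step.

import torch

output = torch.tensor([[[0.1, 0.7, 0.2]],
                       [[0.9, 0.05, 0.05]]])  # a made-up [seq_len=2, batch_size=1, n_class=3] tensor
values, indices = output.max(2, keepdim=True)
print(indices.shape)               # torch.Size([2, 1, 1])
print(indices.squeeze().tolist())  # [1, 0]: the most likely class at each step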
