chatbot --- getting the code to run

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Train a seq2seq chatbot on the chat data file generated from the Cornell movie-dialogs corpus
# Author: zhaoliang
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import os
import itertools
DATA_PATH="/home/zhaoliang/project/chatbot/cornell movie-dialogs corpus/chatfile.txt"
ROOT_PATH="/home/zhaoliang/project/chatbot"
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
PAD_TOKEN=0 # padding token index
S_TOKEN=1 # start-of-sentence (SOS) token index
E_TOKEN=2 # end-of-sentence (EOS) token index
MIN_COUNT=3 # words that appear fewer than MIN_COUNT times in the corpus are discarded from the vocabulary
MAX_LEN=15 # if either sentence in a pair is MAX_LEN words or longer, the whole pair is discarded
BATCH_SIZE=5 
HIDDEN_SIZE=100
lr=0.0001
decoderlr=5 # the decoder's learning rate is lr*decoderlr (see the optimizers below)
clip=50
teacher_forcing_ratio=0.2

encoder_n_layers=1
decoder_n_layers=1
(n_iteration,print_every,save_every)=(30000,200,5000)

#load data: each line of chatfile.txt is expected to hold one "question$answer" pair
pairs=[]
with open(DATA_PATH,'r') as f:
    for line in f:
        pairs.append(line.strip().split("$"))

class VOC:
    def __init__(self):
        self.word2index={}
        self.index2word={PAD_TOKEN: "PAD", S_TOKEN: "SOS", E_TOKEN: "EOS"}
        self.word2count={}
        self.index=3
        self.trimmed=False
    def addWord(self,word):
        if word not in self.word2index:
            self.word2index[word]=self.index
            self.index2word[self.index]=word
            self.word2count[word]=1
            self.index+=1
        else:
            self.word2count[word]+=1
            
    def addSentence(self,sentence):
        for word in sentence.split(' '):
            self.addWord(word)
    def trimVoc(self,MIN_COUNT):
        if self.trimmed:
            return
        self.trimmed=True
        w2i={}
        i2w={PAD_TOKEN: "PAD", S_TOKEN: "SOS", E_TOKEN: "EOS"}
        w2c={}
        print("before trimming there are %d words in vocabulary."%self.index)
        self.index=3
        for word,count in self.word2count.items():
            if count >= MIN_COUNT: # keep words that appear at least MIN_COUNT times
                w2i[word]=self.index
                i2w[self.index]=word
                w2c[word]=count
                self.index+=1
        print("%d words has been trimmed."%self.index)
        self.word2index=w2i
        self.index2word=i2w
        self.word2count=w2c
voc=VOC()
def add2Voc(voc,pairs):
    validPairs=[]
    for pair in pairs:
        if len(pair)==2:
            voc.addSentence(pair[0])
            voc.addSentence(pair[1])
            validPairs.append(pair)
        else:
            print("this pair does not contain exactly two sentences, so it is dropped:\n",pair)
    print("pairs have been added to voc!")
    return validPairs
pairs=add2Voc(voc,pairs)

def inVoc(voc,sentence):
    for word in sentence.split(" "):
        if word.strip() not in voc.word2index:
            return False
    return True

def trim(voc,pairs):
    voc.trimVoc(MIN_COUNT)
    trimmedPairs=[]
    print("before trimming sentence,there are %d sentences pair in pairs."%len(pairs))
    for pair in pairs:
        if len(pair[0].split(" "))< MAX_LEN and len(pair[1].split(" "))< MAX_LEN and inVoc(voc,pair[0]) and inVoc(voc,pair[1]):
            trimmedPairs.append(pair)
    print("after filtering, there are %d sentence pairs left."%len(trimmedPairs))
    return trimmedPairs
pairs=trim(voc,pairs)

class encoderRNN(nn.Module):
    def __init__(self,embedding,dropout=0):
        super(encoderRNN,self).__init__()
        self.embedding=embedding
        # bidirectional GRU over the embedded input batch
        self.gru=nn.GRU(HIDDEN_SIZE,HIDDEN_SIZE,encoder_n_layers,dropout=(dropout if encoder_n_layers>1 else 0),bidirectional=True)
    def forward(self,sentenceBatch,lenthBatch,hidden=None):
        sentenceBatchEmbedding=self.embedding(sentenceBatch)
        packs=torch.nn.utils.rnn.pack_padded_sequence(sentenceBatchEmbedding,lenthBatch)
        outputs, hidden = self.gru(packs, hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
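        # sum the forward and backward direction outputs so the encoder output stays HIDDEN_SIZE wide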
        outputs=outputs[:,:,:HIDDEN_SIZE]+outputs[:,:,HIDDEN_SIZE:]
        return outputs,hidden

embedding=nn.Embedding(voc.index,HIDDEN_SIZE)
encoder=encoderRNN(embedding,0.2).to(device)

class Attention(nn.Module):
    def __init__(self,name):
        super(Attention,self).__init__()
        self.name=name
        if name not in ["dot","general","concat"]:
            raise ValueError(name,"unsupported attention type!")
        elif name=="general":
            self.atte=nn.Linear(HIDDEN_SIZE,HIDDEN_SIZE)
        elif name=="concat":
            self.atte=nn.Linear(HIDDEN_SIZE*2,HIDDEN_SIZE)
            self.v=torch.nn.Parameter(torch.rand(HIDDEN_SIZE))# a learnable vector of length HIDDEN_SIZE
    def forward(self,decoderOutput,encoderOutput):
        if self.name =="dot":
            attenWeight=torch.sum(decoderOutput*encoderOutput,dim=2)
        elif self.name == "general":
            energy=self.atte(encoderOutput)
            attenWeight=torch.sum(decoderOutput*energy,dim=2)
        else:
            energy=self.atte(torch.cat((decoderOutput.expand(encoderOutput.size(0),-1,-1),encoderOutput),dim=2)).tanh()
            attenWeight=torch.sum(self.v*energy,dim=2)
        attenWeight=attenWeight.t() #transpose  equivalent to torch.transpose(attenWeight,0,1)  
        return F.softmax(attenWeight,dim=1).unsqueeze(1)# batchSize*1*maxLen, later bmm-ed with encoderOutput of shape batchSize*maxLen*hiddenSize
class decoderRNN(nn.Module):
    def __init__(self,embedding,attnName,vocSize,dropout=0):
        super(decoderRNN,self).__init__()
        self.embedding=embedding
        self.dropout=nn.Dropout(dropout)
        self.gru=nn.GRU(HIDDEN_SIZE,HIDDEN_SIZE,decoder_n_layers,dropout=(dropout if decoder_n_layers>1 else 0))
        self.attention=Attention(attnName)
        self.concatContextAndOutput=nn.Linear(HIDDEN_SIZE*2,HIDDEN_SIZE)
        self.out=nn.Linear(HIDDEN_SIZE,vocSize)
    def forward(self,inputBatch,hiddenBatch,encoderOutput):
        inputBatch=self.embedding(inputBatch)
        inputBatch=self.dropout(inputBatch)
        output,hidden=self.gru(inputBatch,hiddenBatch)#1*batchSize*hiddenSize
        attenWeight=self.attention(output,encoderOutput)
        context=attenWeight.bmm(encoderOutput.transpose(0,1)).squeeze(1)# batchSize*hiddenSize after the squeeze
        mixture=self.concatContextAndOutput(torch.cat((context,output.squeeze(0)),dim=1)).tanh()
        out=self.out(mixture)
        wordIndex=F.softmax(out,dim=1)
        return wordIndex , hidden  #wordIndex [batchsize*vocSize]
decoder=decoderRNN(embedding,'general',voc.index,0.1).to(device)

def maskNLLLose(inp,target,mask):# inp: softmaxed decoder output [batchSize*vocSize]; target and mask: [batchSize]
    nTotal=mask.sum().item()
    crossEntropy=-torch.log(torch.gather(inp,1,target.view(-1,1)).squeeze(1))# negative log-probability of each target word
    crossEntropy=crossEntropy.masked_select(mask).mean().to(device)# average only over non-padding positions
    return crossEntropy,nTotal
encoder_optimizer=optim.Adam(encoder.parameters(),lr=lr)
decoder_optimizer=optim.Adam(decoder.parameters(),lr=lr*decoderlr)

def train(inputTensor,lengthTensor,outputTensor,maskTensor,maxLen):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    loss=0
    print_losses=[]
    n_totals=0
    encoder_outputs,encoder_hidden=encoder(inputTensor,lengthTensor)
    decoder_input=torch.tensor([[S_TOKEN for _ in range(BATCH_SIZE)]],device=device,dtype=torch.long)
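    # the encoder is bidirectional, so encoder_hidden holds 2*encoder_n_layers layers; the first decoder_n_layers of them seed the decoder's hidden state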
    decoder_hidden=encoder_hidden[:decoder_n_layers]
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    if use_teacher_forcing:
        for t in range(maxLen):
            decoder_output,decoder_hidden=decoder(decoder_input,decoder_hidden,encoder_outputs)
            decoder_input=outputTensor[t].view(1,-1)
            mask_loss,nTotal=maskNLLLose(decoder_output,outputTensor[t],maskTensor[t])
            loss+=mask_loss
            print_losses.append(mask_loss.item()*nTotal)
            n_totals+=nTotal
    else:
        for t in range(maxLen):
            decoder_output,decoder_hidden=decoder(decoder_input,decoder_hidden,encoder_outputs)
            _,topi=decoder_output.topk(1)
            decoder_input=torch.tensor([[topi[i][0] for i in range(BATCH_SIZE)]],device=device,dtype=torch.long)  # sequenceLen * batchSize
            mask_loss,nTotal=maskNLLLose(decoder_output,outputTensor[t],maskTensor[t])
            loss+=mask_loss
            print_losses.append(mask_loss.item()*nTotal)
            n_totals+=nTotal
    loss.backward()
    _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)
    encoder_optimizer.step()
    decoder_optimizer.step()
    return sum(print_losses)/n_totals
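
# trainIters() below builds batches with getTrainBatch(voc,pairBatch), which is referenced but not
# defined earlier in this file. A minimal sketch of the batching helpers, assuming the layout the
# train() loop expects (EOS-terminated, zero-padded, time-major tensors plus a boolean padding mask):
def indexesFromSentence(voc,sentence):
    # map each word to its vocabulary index and terminate with the EOS token
    return [voc.word2index[word] for word in sentence.split(" ")]+[E_TOKEN]
def zeroPadding(l,fillvalue=PAD_TOKEN):
    # transpose a list of index lists into time-major form, padding short sentences with PAD_TOKEN
    return list(itertools.zip_longest(*l,fillvalue=fillvalue))
def binaryMatrix(l):
    # 1 where a real token is present, 0 where the position is padding
    return [[0 if token==PAD_TOKEN else 1 for token in seq] for seq in l]
def getTrainBatch(voc,pairBatch):
    # sort by input length, longest first, because pack_padded_sequence expects decreasing lengths
    pairBatch.sort(key=lambda p:len(p[0].split(" ")),reverse=True)
    inputIndexes=[indexesFromSentence(voc,p[0]) for p in pairBatch]
    outputIndexes=[indexesFromSentence(voc,p[1]) for p in pairBatch]
    lengthTensor=torch.tensor([len(s) for s in inputIndexes])# kept on the CPU for packing
    maxLen=max(len(s) for s in outputIndexes)
    inputTensor=torch.LongTensor(zeroPadding(inputIndexes)).to(device)# maxInputLen*batchSize
    outputPadded=zeroPadding(outputIndexes)
    maskTensor=torch.tensor(binaryMatrix(outputPadded),dtype=torch.bool,device=device)# maxLen*batchSize
    outputTensor=torch.LongTensor(outputPadded).to(device)# maxLen*batchSize
    return inputTensor,lengthTensor,outputTensor,maskTensor,maxLen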

    
def trainIters():
    training_batches=[getTrainBatch(voc,[random.choice(pairs) for _ in range(BATCH_SIZE)]) for _ in range(n_iteration)]
    print("Initializing...")
    start_iteration=1
    print_loss=0
    print("Start Training...")
    for i in range(start_iteration,n_iteration+1):
        inputTensor,lengthTensor,outputTensor,maskTensor,maxLen=training_batches[i-1]
        loss=train(inputTensor,lengthTensor,outputTensor,maskTensor,maxLen)
        print_loss+=loss
        if i% print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: %d; Percent complete: %.1f; Average loss: %.4f"%(i,i/n_iteration*100,print_loss_avg))
            print_loss=0
        if i%save_every == 0:
            directory=os.path.join(ROOT_PATH,'chatbot_{}'.format(i))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({'iteration':i,
            'en':encoder.state_dict(),
            'de':decoder.state_dict(),
            "en_opt":encoder_optimizer.state_dict(),
            "de_opt":decoder_optimizer.state_dict(),
            "loss":loss,
            'voc_dict':voc.__dict__,
            'embedding':embedding.state_dict()},os.path.join(directory,'{}_{}.tar'.format(i,"checkpoint")))

trainIters()
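
# The checkpoints written above can later be restored into freshly constructed modules.
# A minimal sketch (the path assumes training ran for at least save_every iterations):
#   checkpoint=torch.load(os.path.join(ROOT_PATH,'chatbot_5000','5000_checkpoint.tar'))
#   voc.__dict__.update(checkpoint['voc_dict'])
#   embedding.load_state_dict(checkpoint['embedding'])
#   encoder.load_state_dict(checkpoint['en'])
#   decoder.load_state_dict(checkpoint['de'])
#   encoder_optimizer.load_state_dict(checkpoint['en_opt'])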



        
