#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Trains a seq2seq chatbot on a chat datafile generated from the Cornell movie-dialogs corpus
# Author: zhaoliang
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import os
import itertools
DATA_PATH="/home/zhaoliang/project/chatbot/cornell movie-dialogs corpus/chatfile.txt"
ROOT_PATH="/home/zhaoliang/project/chatbot"
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
PAD_TOKEN=0
S_TOKEN=1
E_TOKEN=2
MIN_COUNT=3 # words appearing MIN_COUNT times or fewer in the corpus are discarded from the vocabulary
MAX_LEN=15 # if either sentence of a pair is MAX_LEN words or longer, the entire pair is discarded
BATCH_SIZE=5
HIDDEN_SIZE=100
lr=0.0001
decoderlr=5 # multiplier: the decoder optimizer runs at lr*decoderlr
clip=50
teacher_forcing_ratio=0.2
encoder_n_layers=1
decoder_n_layers=1
(n_iteration,print_every,save_every)=(30000,200,5000)
#load data
pairs=[]
with open(DATA_PATH,'r') as f:
    for line in f:
        pairs.append(line.strip('\n').split("$")) # strip the newline so it doesn't stick to the last word
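# Each line of chatfile.txt is assumed to hold one "$"-separated query/response pair,
# e.g. "how are you ?$i am fine ." (the exact file format is inferred from the split above).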
class VOC:
def __init__(self):
self.word2index={}
self.index2word={PAD_TOKEN: "PAD", S_TOKEN: "SOS", E_TOKEN: "EOS"}
self.word2count={}
self.index=3
self.trimmed=False
def addWord(self,word):
if word not in self.word2index:
self.word2index[word]=self.index
self.index2word[self.index]=word
self.word2count[word]=1
self.index+=1
else:
self.word2count[word]+=1
def addSentence(self,sentence):
for word in sentence.split(' '):
self.addWord(word)
def trimVoc(self,MIN_COUNT):
if self.trimmed:
return
self.trimmed=True
w2i={}
i2w={PAD_TOKEN: "PAD", S_TOKEN: "SOS", E_TOKEN: "EOS"}
w2c={}
print("before trimming there are %d words in vocabulary."%self.index)
self.index=3
for word,count in self.word2count.items():
if count > MIN_COUNT:
w2i[word]=self.index
i2w[self.index]=word
w2c[word]=count
self.index+=1
print("%d words has been trimmed."%self.index)
self.word2index=w2i
self.index2word=i2w
self.word2count=w2c
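# Example: after voc.addSentence("hello world hello"), voc.word2count["hello"] == 2
# and voc.index2word[3] == "hello" (indices 0-2 are reserved for PAD/SOS/EOS).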
voc=VOC()
def add2Voc(voc,pairs):
    valid=[]
    for pair in pairs:
        if len(pair)==2:
            voc.addSentence(pair[0])
            voc.addSentence(pair[1])
            valid.append(pair)
        else:
            print("this pair does not contain exactly two sentences and has been removed:\n",pair)
    pairs[:]=valid # filter in place so callers see the cleaned list
    print("pairs have been added to voc!")
add2Voc(voc,pairs)
def inVoc(voc,sentence): # True only if every word of the sentence is still in the (trimmed) vocabulary
for word in sentence.split(" "):
if word.strip() not in voc.word2index:
return False
return True
def trim(voc,pairs):
voc.trimVoc(MIN_COUNT)
trimmedPairs=[]
print("before trimming sentence,there are %d sentences pair in pairs."%len(pairs))
for pair in pairs:
if len(pair[0].split(" "))< MAX_LEN and len(pair[1].split(" "))1 else 0),bidirectional=True)
    def forward(self,sentenceBatch,lengthBatch,hidden=None):
        sentenceBatchEmbedding=self.embedding(sentenceBatch)
        # pack the padded batch so the GRU skips PAD positions (lengths must be sorted in descending order)
        packs=torch.nn.utils.rnn.pack_padded_sequence(sentenceBatchEmbedding,lengthBatch)
        outputs, hidden = self.gru(packs, hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # sum the forward and backward directions so outputs stay HIDDEN_SIZE wide
        outputs=outputs[:,:,:HIDDEN_SIZE]+outputs[:,:,HIDDEN_SIZE:]
        return outputs,hidden
embedding=nn.Embedding(voc.index,HIDDEN_SIZE) # voc.index == current vocabulary size
encoder=encoderRNN(embedding,0.2).to(device)
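# The batching helpers used by trainIters below (getTrainBatch and friends) are missing
# from this file. What follows is a minimal sketch reconstructed from the call sites:
# sentences are indexed with an EOS terminator, zero-padded into time-major tensors, and
# sorted by input length for pack_padded_sequence. The helper names indexesFromSentence,
# zeroPadding and binaryMatrix are assumptions, not the author's confirmed code.
def indexesFromSentence(voc,sentence):
    return [voc.word2index[word] for word in sentence.split(' ')]+[E_TOKEN]
def zeroPadding(batch,fillvalue=PAD_TOKEN):
    # transpose a list of index lists into maxLen*batchSize, padding with PAD_TOKEN
    return list(itertools.zip_longest(*batch,fillvalue=fillvalue))
def binaryMatrix(padded):
    # 1 where there is a real token, 0 where there is padding
    return [[0 if token==PAD_TOKEN else 1 for token in seq] for seq in padded]
def getTrainBatch(voc,pairBatch):
    # sort by input length (descending) as required by pack_padded_sequence
    pairBatch.sort(key=lambda p:len(p[0].split(" ")),reverse=True)
    inputBatch=[indexesFromSentence(voc,p[0]) for p in pairBatch]
    outputBatch=[indexesFromSentence(voc,p[1]) for p in pairBatch]
    lengthTensor=torch.tensor([len(s) for s in inputBatch]) # kept on CPU for packing
    maxLen=max(len(s) for s in outputBatch)
    inputTensor=torch.LongTensor(zeroPadding(inputBatch)).to(device)
    paddedOutput=zeroPadding(outputBatch)
    outputTensor=torch.LongTensor(paddedOutput).to(device)
    maskTensor=torch.tensor(binaryMatrix(paddedOutput),dtype=torch.bool,device=device)
    return inputTensor,lengthTensor,outputTensor,maskTensor,maxLen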
class Attention(nn.Module):
def __init__(self,name):
super(Attention,self).__init__()
self.name=name
        if name not in ["dot","general","concat"]:
            raise ValueError(name,"Unsupported attention type!")
        elif name=="general":
            self.atte=nn.Linear(HIDDEN_SIZE,HIDDEN_SIZE)
        elif name=="concat":
            self.atte=nn.Linear(HIDDEN_SIZE*2,HIDDEN_SIZE)
            # learnable scoring vector of size HIDDEN_SIZE for the concat score v^T tanh(W[h;s])
            self.v=torch.nn.Parameter(torch.rand(HIDDEN_SIZE))
def forward(self,decoderOutput,encoderOutput):
if self.name =="dot":
attenWeight=torch.sum(decoderOutput*encoderOutput,dim=2)
elif self.name == "general":
energy=self.atte(encoderOutput)
attenWeight=torch.sum(decoderOutput*energy,dim=2)
else:
energy=self.atte(torch.cat((decoderOutput.expand(encoderOutput.size(0),-1,-1),encoderOutput),dim=2)).tanh()
attenWeight=torch.sum(self.v*energy,dim=2)
        attenWeight=attenWeight.t() # transpose maxLen*batchSize -> batchSize*maxLen (same as torch.transpose(attenWeight,0,1))
        return F.softmax(attenWeight,dim=1).unsqueeze(1) # batchSize*1*maxLen, later bmm'ed with encoder outputs of batchSize*maxLen*hiddenSize
class decoderRNN(nn.Module):
def __init__(self,embedding,attnName,vocSize,dropout=0):
super(decoderRNN,self).__init__()
self.embedding=embedding
self.dropout=nn.Dropout(dropout)
self.gru=nn.GRU(HIDDEN_SIZE,HIDDEN_SIZE,decoder_n_layers,dropout=(dropout if decoder_n_layers>1 else 0))
self.attention=Attention(attnName)
self.concatContextAndOutput=nn.Linear(HIDDEN_SIZE*2,HIDDEN_SIZE)
self.out=nn.Linear(HIDDEN_SIZE,vocSize)
def forward(self,inputBatch,hiddenBatch,encoderOutput):
inputBatch=self.embedding(inputBatch)
inputBatch=self.dropout(inputBatch)
output,hidden=self.gru(inputBatch,hiddenBatch)#1*batchSize*hiddenSize
attenWeight=self.attention(output,encoderOutput)
        context=attenWeight.bmm(encoderOutput.transpose(0,1)).squeeze(1) # batchSize*hiddenSize after the squeeze
mixture=self.concatContextAndOutput(torch.cat((context,output.squeeze(0)),dim=1)).tanh()
out=self.out(mixture)
        wordProbs=F.softmax(out,dim=1) # probability distribution, not an index
        return wordProbs,hidden # wordProbs: batchSize*vocSize
decoder=decoderRNN(embedding,'general',voc.index,0.1).to(device)
def maskNLLLoss(inp,target,mask): # masked cross-entropy for one decoder step; inp is 2-dimensional (batchSize*vocSize)
    nTotal=mask.sum().item() # number of non-PAD targets at this step
    crossEntropy=-torch.log(torch.gather(inp,1,target.view(-1,1)).squeeze(1))
    crossEntropy=crossEntropy.masked_select(mask).mean().to(device)
    return crossEntropy,nTotal
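# Shape conventions for maskNLLLoss, matching the tensors built in getTrainBatch:
#   inp    batchSize*vocSize  softmax probabilities from the decoder
#   target batchSize          gold word indices for this time step
#   mask   batchSize          bool, False at PAD positions (excluded from the mean)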
encoder_optimizer=optim.Adam(encoder.parameters(),lr=lr)
decoder_optimizer=optim.Adam(decoder.parameters(),lr=lr*decoderlr)
def train(inputTensor,lengthTensor,outputTensor,maskTensor,maxLen):
encoder_optimizer.zero_grad()
decoder_optimizer.zero_grad()
loss=0
print_losses=[]
n_totals=0
encoder_outputs,encoder_hidden=encoder(inputTensor,lengthTensor)
decoder_input=torch.tensor([[S_TOKEN for _ in range(BATCH_SIZE)]],device=device,dtype=torch.long)
    decoder_hidden=encoder_hidden[:decoder_n_layers] # keep only the first decoder_n_layers layers of the bidirectional encoder state
    use_teacher_forcing = random.random() < teacher_forcing_ratio
if use_teacher_forcing:
for t in range(maxLen):
decoder_output,decoder_hidden=decoder(decoder_input,decoder_hidden,encoder_outputs)
            decoder_input=outputTensor[t].view(1,-1) # teacher forcing: feed the gold token as the next input
            mask_loss,nTotal=maskNLLLoss(decoder_output,outputTensor[t],maskTensor[t])
loss+=mask_loss
print_losses.append(mask_loss.item()*nTotal)
n_totals+=nTotal
else:
for t in range(maxLen):
decoder_output,decoder_hidden=decoder(decoder_input,decoder_hidden,encoder_outputs)
_,topi=decoder_output.topk(1)
            decoder_input=torch.tensor([[topi[i][0] for i in range(BATCH_SIZE)]],device=device,dtype=torch.long) # 1*batchSize: feed the model's own best guess
            mask_loss,nTotal=maskNLLLoss(decoder_output,outputTensor[t],maskTensor[t])
loss+=mask_loss
print_losses.append(mask_loss.item()*nTotal)
n_totals+=nTotal
loss.backward()
_ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
_ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)
encoder_optimizer.step()
decoder_optimizer.step()
return sum(print_losses)/n_totals
def trainIters():
training_batches=[getTrainBatch(voc,[random.choice(pairs) for _ in range(BATCH_SIZE)]) for _ in range(n_iteration)]
print("Initializing...")
start_iteration=1
print_loss=0
print("Start Training...")
for i in range(start_iteration,n_iteration+1):
        inputTensor,lengthTensor,outputTensor,maskTensor,maxLen=training_batches[i-1]
loss=train(inputTensor,lengthTensor,outputTensor,maskTensor,maxLen)
print_loss+=loss
if i% print_every == 0:
print_loss_avg = print_loss / print_every
print("Iteration: %d; Percent complete: %.1f; Average loss: %.4f"%(i,i/n_iteration*100,print_loss_avg))
print_loss=0
if i%save_every == 0:
directory=os.path.join(ROOT_PATH,'chatbot_{}'.format(i))
if not os.path.exists(directory):
os.makedirs(directory)
            torch.save({'iteration':i,
                        'en':encoder.state_dict(),
                        'de':decoder.state_dict(),
                        'en_opt':encoder_optimizer.state_dict(),
                        'de_opt':decoder_optimizer.state_dict(), # save the decoder optimizer too so training can resume
                        'loss':loss,
                        'voc_dict':voc.__dict__,
                        'embedding':embedding.state_dict()},os.path.join(directory,'{}_{}.tar'.format(i,"checkpoint")))
trainIters()
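# A minimal sketch of restoring a saved checkpoint (the example path below is hypothetical);
# the keys mirror the dict saved in trainIters above.
def loadCheckpoint(path):
    checkpoint=torch.load(path,map_location=device)
    embedding.load_state_dict(checkpoint['embedding'])
    encoder.load_state_dict(checkpoint['en'])
    decoder.load_state_dict(checkpoint['de'])
    encoder_optimizer.load_state_dict(checkpoint['en_opt'])
    decoder_optimizer.load_state_dict(checkpoint['de_opt'])
    voc.__dict__.update(checkpoint['voc_dict'])
    return checkpoint['iteration']
# e.g. startIteration=loadCheckpoint(os.path.join(ROOT_PATH,'chatbot_5000','5000_checkpoint.tar'))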