涉及的论文
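A Neural Conversational Model (Vinyals & Le) https://arxiv.org/abs/1506.05869
Luong et al., Effective Approaches to Attention-based Neural Machine Translation (Luong attention mechanisms) https://arxiv.org/abs/1508.04025
Sutskever et al., Sequence to Sequence Learning with Neural Networks https://arxiv.org/abs/1409.3215
Cho et al., Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation (introduces the GRU) https://arxiv.org/pdf/1406.1078v3.pdf
Bahdanau et al., Neural Machine Translation by Jointly Learning to Align and Translate https://arxiv.org/abs/1409.0473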
使用的数据集
Corpus homepage: https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html
Corpus download (zip): http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
代码列表
chatbot_test.py
chatbot_train.py
corpus_dataset.py
vocabulary.py
graph.py
model.py
etc.py
main.py
chatbot_test.py
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import corpus_dataset
import graph
import etc
def run_test():
    """Load the vocabulary and the saved checkpoint, then start the interactive chat loop."""
    cfg = etc.config
    # Only the vocabulary is needed for chatting; the sentence pairs are ignored.
    voc, _unused_pairs = corpus_dataset.load_vocabulary_and_pairs(cfg)
    corpus_graph = graph.CorpusGraph(cfg)
    # The "test" status makes create_train_model use the test checkpoint loader.
    model_bundle = corpus_graph.create_train_model(voc, "test")
    corpus_graph.evaluate_input(voc, model_bundle)
chatbot_train.py
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import corpus_dataset
import graph
import etc
def run_train():
    """Load the training data, build (or resume) the model and run the training loop."""
    cfg = etc.config
    voc, pairs = corpus_dataset.load_vocabulary_and_pairs(cfg)
    corpus_graph = graph.CorpusGraph(cfg)
    print("Create model")
    model_bundle = corpus_graph.create_train_model(voc)
    print("Starting Training!")
    corpus_graph.trainIters(voc, pairs, model_bundle)
corpus_dataset.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : corpus_dataset.py
# Create date : 2019-01-16 11:16
# Modified date : 2019-02-02 14:55
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
import os
import re
import csv
import codecs
import unicodedata
import vocabulary
def _check_is_have_file(file_name):
    """Return True when *file_name* names an existing filesystem path (file or directory)."""
    path_exists = os.path.exists(file_name)
    return path_exists
def _filter_pair(p, max_length):
    """Return True when both p[0] and p[1] contain fewer than *max_length* space-separated words.

    p[1] is only inspected when p[0] passes (short-circuit, like ``and``).
    """
    return all(len(p[idx].split(' ')) < max_length for idx in (0, 1))
def _filter_pairs(pairs, max_length):
    """Return the pairs accepted by _filter_pair, in their original order."""
    return list(filter(lambda pair: _filter_pair(pair, max_length), pairs))
def _read_vocabulary(datafile, corpus_name):
    """Read the formatted pair file and create an empty vocabulary for it.

    Every line of *datafile* is split on tab characters and each field is
    passed through normalize_string. No words are added to the vocabulary
    here; the caller does that after filtering.

    NOTE(review): lines are split on a literal tab regardless of
    config["delimiter"]; reading only matches the writer while that setting
    is a tab.

    Returns:
        (voc, pairs): a fresh ``vocabulary.Voc`` named *corpus_name* and a
        list with one list of normalized fields per line.
    """
    print("Reading lines...")
    # Context manager so the handle is closed even if reading fails.
    with open(datafile, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    pairs = [[normalize_string(s) for s in line.split('\t')] for line in lines]
    voc = vocabulary.Voc(corpus_name)
    return voc, pairs
def _unicode_to_ascii(s):
    """Drop combining marks (Unicode category 'Mn') after NFD decomposition.

    This removes accents from Latin letters (e.g. "é" -> "e"). Other
    non-ASCII characters are kept unchanged despite the function name.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = []
    for ch in decomposed:
        if unicodedata.category(ch) == 'Mn':
            continue
        kept.append(ch)
    return ''.join(kept)
def _get_delimiter(config):
    """Return config["delimiter"] with backslash escape sequences decoded.

    A configured two-character backslash-t therefore becomes a real tab; an
    actual tab character is returned unchanged.
    """
    raw_delimiter = config["delimiter"]
    return str(codecs.decode(raw_delimiter, "unicode_escape"))
def _get_object(line, fields):
    """Split a raw corpus line on " +++$+++ " and map the pieces onto *fields* by position.

    Extra pieces are ignored; IndexError is raised when the line has fewer
    pieces than *fields*. The last value keeps any trailing newline.
    """
    values = line.split(" +++$+++ ")
    return {name: values[idx] for idx, name in enumerate(fields)}
def _load_lines(config):
    """Parse the movie-lines file into a dict keyed by each line's 'lineID' field.

    The path is "<corpus_path>/<lines_file_name>" and each row is split into
    config["movie_lines_fields"] by _get_object. A later row with the same
    lineID overwrites an earlier one.
    """
    lines_file_full_path = "%s/%s" % (config["corpus_path"], config["lines_file_name"])
    fields = config["movie_lines_fields"]
    lines = {}
    # Context manager: the original manual close() was skipped on any parse error.
    with open(lines_file_full_path, 'r', encoding='iso-8859-1') as f:
        for line in f:
            line_obj = _get_object(line, fields)
            lines[line_obj['lineID']] = line_obj
    return lines
def _cellect_lines(conv_obj, lines):
    """Attach the line dicts listed in conv_obj["utteranceIDs"] as conv_obj["lines"].

    conv_obj["utteranceIDs"] is the text of a Python list literal, e.g.
    "['L598485', 'L598486', ...]". It is parsed with ast.literal_eval instead
    of eval, so a malformed or hostile corpus file cannot execute code
    (non-literal input raises ValueError).

    Mutates and returns *conv_obj*. Raises KeyError if an ID is not in *lines*.
    """
    import ast  # local import: keeps this module's import block unchanged

    line_ids = ast.literal_eval(conv_obj["utteranceIDs"])
    conv_obj["lines"] = [lines[line_id] for line_id in line_ids]
    return conv_obj
def _load_conversations(lines, config):
    """Parse the conversations file and resolve each conversation's utterances.

    The path is "<corpus_path>/<conversation_file_name>". Each row is split
    into config["movie_conversations_fields"], then _cellect_lines looks up
    its utterance IDs in *lines* (the dict built by _load_lines).

    Returns:
        list of conversation dicts, each with an added "lines" entry.
    """
    conversation_file_full_path = "%s/%s" % (config["corpus_path"], config["conversation_file_name"])
    fields = config["movie_conversations_fields"]
    conversations = []
    # Context manager: the original manual close() was skipped on any parse error.
    with open(conversation_file_full_path, 'r', encoding='iso-8859-1') as f:
        for line in f:
            conv_obj = _get_object(line, fields)
            conversations.append(_cellect_lines(conv_obj, lines))
    return conversations
def _get_conversations(config):
    """Load every movie line, then group the lines into conversations.

    Prints the number of lines and conversations loaded and returns the
    list produced by _load_conversations.
    """
    id_to_line = _load_lines(config)
    print("lines count:", len(id_to_line))
    result = _load_conversations(id_to_line, config)
    print("conversations count:", len(result))
    return result
def _extract_sentence_pairs(conversations):
    """Build [utterance, reply] pairs from consecutive lines of each conversation.

    Both texts are stripped, and a pair is skipped when either side is empty.
    The last line of a conversation only ever appears as a reply.
    """
    pairs = []
    for conversation in conversations:
        utterances = conversation["lines"]
        for current, following in zip(utterances, utterances[1:]):
            question = current["text"].strip()
            answer = following["text"].strip()
            if not (question and answer):
                continue
            pairs.append([question, answer])
    return pairs
def _load_formatted_data(config):
    """Read the formatted pair file, drop over-long pairs, and count their words.

    Returns:
        (voc, pairs): the vocabulary filled from the kept pairs, and those pairs.
    """
    max_len = config["max_length"]
    name = config["corpus_name"]
    source_path = get_formatted_file_full_path(config)
    print("Start preparing training data ...")
    voc, pairs = _read_vocabulary(source_path, name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = _filter_pairs(pairs, max_len)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    for pair in pairs:
        question, answer = pair[0], pair[1]
        voc.addSentence(question)
        voc.addSentence(answer)
    print("Counted words:", voc.num_words)
    return voc, pairs
def _trim_rare_words(voc, pairs, min_count):
    """Trim *voc* to frequent words and keep only pairs made entirely of surviving words.

    Args:
        voc: vocabulary with ``trim(min_count)`` and a ``word2index`` dict.
        pairs: list of [input_sentence, output_sentence] strings.
        min_count: forwarded to ``voc.trim``.

    Returns:
        The pairs whose input and output words are all in voc.word2index
        after trimming, in their original order.
    """
    voc.trim(min_count)
    known = voc.word2index

    def _all_known(sentence):
        # A sentence is kept only if every space-separated word survived trimming.
        return all(word in known for word in sentence.split(' '))

    keep_pairs = [pair for pair in pairs if _all_known(pair[0]) and _all_known(pair[1])]
    # Guard the ratio: an empty pair list previously raised ZeroDivisionError.
    ratio = len(keep_pairs) / len(pairs) if pairs else 0.0
    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), ratio))
    return keep_pairs
def _write_newly_formatted_file(config):
    """Create the formatted pair file from the raw corpus unless it already exists.

    One row is written per [utterance, reply] pair, using the configured
    delimiter and '\\n' line endings.

    The output file is opened only after the pairs have been extracted, and
    it is closed by a context manager. The original opened it first and never
    closed it. A failed extraction then left an empty file behind, and later
    runs treated that empty file as finished and skipped it.
    """
    formatted_file_full_path = get_formatted_file_full_path(config)
    if _check_is_have_file(formatted_file_full_path):
        print("%s already has the formatted file,so we do not write" % formatted_file_full_path)
        return
    delimiter = _get_delimiter(config)
    conversations = _get_conversations(config)
    pairs = _extract_sentence_pairs(conversations)
    print("pairs count:", len(pairs))
    print("\nWriting newly formatted file...")
    with open(formatted_file_full_path, 'w', encoding='utf-8') as outputfile:
        writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
        writer.writerows(pairs)
def load_vocabulary_and_pairs(config):
    """Ensure the formatted file exists, load it, and drop rare words.

    Returns:
        (voc, pairs): the trimmed vocabulary and the pairs that only use
        words kept in it.
    """
    _write_newly_formatted_file(config)
    voc, raw_pairs = _load_formatted_data(config)
    return voc, _trim_rare_words(voc, raw_pairs, config["min_count"])
def get_formatted_file_full_path(config):
    """Return "<corpus_path>/<formatted_file_name>" built from *config*."""
    file_name = config["formatted_file_name"]
    directory = config["corpus_path"]
    return "{}/{}".format(directory, file_name)
def normalize_string(s):
    """Normalize a sentence before vocabulary lookup.

    Steps:
    1. Strip the sentence, lower-case it and remove combining marks.
    2. Put a space before every '.', '!' and '?'.
    3. Replace each run of characters other than ASCII letters and . ! ?
       with one space.
    4. Collapse whitespace and strip the result.
    """
    text = _unicode_to_ascii(s.lower().strip())
    substitutions = (
        (r"([.!?])", r" \1"),
        (r"[^a-zA-Z.!?]+", r" "),
        (r"\s+", r" "),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()
vocabulary.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : vocabulary.py
# Create date : 2019-01-16 11:21
# Modified date : 2019-02-02 13:37
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token


class Voc:
    """Word <-> index mapping with per-word occurrence counts for one corpus.

    Indices 0-2 are reserved for PAD/SOS/EOS; real words start at index 3.
    """

    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self._reset()

    def _reset(self):
        """Empty the vocabulary, keeping only the three special tokens."""
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        """Add every space-separated token of *sentence*."""
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Index *word* if it is new; otherwise increment its count."""
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    def trim(self, min_count):
        """Keep only words seen at least *min_count* times (runs once per instance).

        The vocabulary is rebuilt from scratch. Surviving words get new
        indices, and their counts restart at 1.
        """
        if self.trimmed:
            return
        self.trimmed = True
        keep_words = [word for word, count in self.word2count.items() if count >= min_count]
        total = len(self.word2index)
        # Guard the ratio: trimming an empty vocabulary previously raised ZeroDivisionError.
        ratio = len(keep_words) / total if total else 0.0
        print('keep_words {} / {} = {:.4f}'.format(len(keep_words), total, ratio))
        self._reset()
        for word in keep_words:
            self.addWord(word)
graph.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : graph.py
# Create date : 2019-01-16 11:44
# Modified date : 2019-02-02 14:55
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
import os
import itertools
import random
import torch
import torch.nn as nn
from torch import optim
import vocabulary
import model
import corpus_dataset
def _get_training_batches(voc, pairs, batch_size, n_iteration):
    """Pre-build *n_iteration* training batches of *batch_size* pairs each.

    Pairs are drawn with replacement via random.choice, and every batch is
    converted to tensors by _batch2TrainData.
    """
    def _one_batch():
        sampled = [random.choice(pairs) for _ in range(batch_size)]
        return _batch2TrainData(voc, sampled)

    return [_one_batch() for _ in range(n_iteration)]
def _zero_padding(l, fillvalue=vocabulary.PAD_token):
    """Transpose index lists into time-major rows, padding short lists with *fillvalue*.

    Example with fillvalue=0: [[1, 2, 3], [4]] -> [(1, 4), (2, 0), (3, 0)].
    """
    padded_rows = itertools.zip_longest(*l, fillvalue=fillvalue)
    return list(padded_rows)
def _binary_matrix(lt):
    """Return a same-shaped 0/1 mask of *lt*: 0 where a token equals PAD_token, else 1."""
    return [
        [0 if token == vocabulary.PAD_token else 1 for token in seq]
        for seq in lt
    ]
def _get_indexes_batch(lt, voc):
    """Convert each sentence in *lt* to its word-index list (EOS appended)."""
    return list(map(lambda sentence: _indexes_from_sentence(voc, sentence), lt))
def _input_var(batch, voc):
    """Encode input sentences for the encoder.

    Returns:
        (variable, lengths): a padded LongTensor laid out time-major by
        _zero_padding, so (max_len, batch), and a tensor with each
        sentence's unpadded index count.
    """
    index_lists = _get_indexes_batch(batch, voc)
    lengths = torch.tensor([len(ids) for ids in index_lists])
    padded = torch.LongTensor(_zero_padding(index_lists))
    return padded, lengths
def _output_var(batch, voc):
    """Encode target sentences for the decoder loss.

    Returns:
        variable: padded LongTensor laid out time-major by _zero_padding
            (max_target_len, batch).
        mask: bool tensor of the same shape, True where *variable* is not PAD.
        max_target_len: length of the longest index list (EOS included).
    """
    indexes_batch = _get_indexes_batch(batch, voc)
    padList = _zero_padding(indexes_batch)
    variable = torch.LongTensor(padList)
    max_target_len = max(len(indexes) for indexes in indexes_batch)
    # masked_select (used by _maskNLLLoss) needs a bool mask; uint8 ByteTensor
    # masks are deprecated and rejected by current PyTorch.
    mask = torch.tensor(_binary_matrix(padList), dtype=torch.bool)
    return variable, mask, max_target_len
def _indexes_from_sentence(voc, sentence):
    """Look up every space-separated word of *sentence* in voc.word2index and append EOS_token.

    Raises KeyError for an unknown word; evaluate_input catches this to
    report unknown words.
    """
    return [voc.word2index[word] for word in sentence.split(' ')] + [vocabulary.EOS_token]
def _batch2TrainData(voc, pair_batch):
    """Turn a list of [input, target] sentence pairs into one step's training tensors.

    *pair_batch* is sorted in place, longest input (by word count) first.
    EncoderRNN calls pack_padded_sequence with its default enforce_sorted=True,
    which expects lengths in decreasing order.

    Returns:
        (input_variable, lengths, target_variable, mask, max_target_len)
    """
    pair_batch.sort(key=lambda pair: len(pair[0].split(" ")), reverse=True)
    inputs = [pair[0] for pair in pair_batch]
    targets = [pair[1] for pair in pair_batch]
    input_variable, lengths = _input_var(inputs, voc)
    target_variable, mask, max_target_len = _output_var(targets, voc)
    return input_variable, lengths, target_variable, mask, max_target_len
def _maskNLLLoss(inp, target, mask, device):
    """Mean negative log-likelihood over the unmasked positions of one decoder step.

    Args:
        inp: probabilities gathered along dim 1, so (batch, vocab). The
            decoder already applies softmax, so no log-softmax is applied here.
        target: gold indices, reshaped to (batch, 1) for the gather.
        mask: nonzero/True at real (non-PAD) positions; any dtype. It is
            converted to bool because masked_select rejects uint8 masks in
            current PyTorch.
        device: device the returned loss is moved to.

    Returns:
        (loss, nTotal): scalar loss tensor and the number of unmasked
        positions as a Python int.
    """
    mask = mask.bool()
    nTotal = mask.sum()
    cross_entropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = cross_entropy.masked_select(mask).mean()
    return loss.to(device), nTotal.item()
class CorpusGraph(nn.Module):
    """Training and interactive-evaluation driver for the seq2seq chatbot.

    Hyper-parameters are copied from the *config* dict passed to __init__.
    The model is kept in a plain dict built by create_train_model with the
    keys "encoder", "decoder", "encoder_optimizer", "decoder_optimizer",
    "embedding" and "iteration" (last completed global training step).
    """

    def __init__(self, config):
        super(CorpusGraph, self).__init__()
        self.model_name = config["model_name"]
        self.save_dir = config["save_dir"]
        self.corpus_name = config["corpus_name"]
        self.encoder_n_layers = config["encoder_n_layers"]
        self.decoder_n_layers = config["decoder_n_layers"]
        self.hidden_size = config["hidden_size"]
        self.checkpoint_iter = config["checkpoint_iter"]
        self.learning_rate = config["learning_rate"]
        self.decoder_learning_ratio = config["decoder_learning_ratio"]
        self.dropout = config["dropout"]
        self.attn_model = config["attn_model"]
        self.device = config["device"]
        self.print_every = config["print_every"]
        self.save_every = config["save_every"]
        self.n_iteration = config["n_iteration"]
        self.batch_size = config["batch_size"]
        self.clip = config["clip"]
        self.max_length = config["max_length"]
        self.teacher_forcing_ratio = config["teacher_forcing_ratio"]
        self.train_load_checkpoint_file = config["train_load_checkpoint_file"]

    def _evaluate(self, voc, sentence, train_model):
        """Greedy-decode a reply to one normalized *sentence*; returns a list of words.

        Raises KeyError when *sentence* contains a word unknown to *voc*.
        """
        encoder = train_model["encoder"]
        decoder = train_model["decoder"]
        # Switch dropout layers to inference behaviour.
        encoder.eval()
        decoder.eval()
        searcher = model.GreedySearchDecoder(encoder, decoder)
        indexes_batch = [_indexes_from_sentence(voc, sentence)]
        # Lengths stay on the CPU: pack_padded_sequence rejects CUDA lengths (PyTorch >= 1.7).
        lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
        # (1, seq_len) -> time-major (seq_len, 1)
        input_batch = torch.LongTensor(indexes_batch).transpose(0, 1).to(self.device)
        # No gradients are needed for inference.
        with torch.no_grad():
            tokens, _scores = searcher(input_batch, lengths, self.max_length, self.device)
        return [voc.index2word[token.item()] for token in tokens]

    def _choose_use_teacher_forcing(self):
        """Return True with probability self.teacher_forcing_ratio."""
        return random.random() < self.teacher_forcing_ratio

    def _train_step(self, decoder, decoder_input, decoder_hidden, encoder_outputs, target_variable, mask, max_target_len):
        """Run the decoder for *max_target_len* steps and accumulate the masked loss.

        Teacher forcing is decided independently at every step.

        Returns:
            (loss, print_losses, n_totals): the summed loss tensor to
            back-propagate, per-step ``loss * token_count`` values, and the
            total number of unmasked tokens.
        """
        loss = 0
        print_losses = []
        n_totals = 0
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            if self._choose_use_teacher_forcing():
                # Feed the gold token of this step as the next input.
                decoder_input = target_variable[t].view(1, -1)
            else:
                # Feed the model's own best guess: topi is (batch, 1) -> (1, batch).
                _, topi = decoder_output.topk(1)
                decoder_input = topi.view(1, -1).to(self.device)
            mask_loss, nTotal = _maskNLLLoss(decoder_output, target_variable[t], mask[t], self.device)
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
        return loss, print_losses, n_totals

    def _train_init(self, input_variable, lengths, target_variable, mask, train_model):
        """Zero the gradients, run the encoder, and build the first decoder input and hidden state."""
        encoder = train_model["encoder"]
        decoder = train_model["decoder"]
        train_model["encoder_optimizer"].zero_grad()
        train_model["decoder_optimizer"].zero_grad()
        input_variable = input_variable.to(self.device)
        # lengths is deliberately left on the CPU (required by pack_padded_sequence).
        target_variable = target_variable.to(self.device)
        mask = mask.to(self.device)
        encoder_outputs, encoder_hidden = encoder(input_variable, lengths)
        # One SOS token per sequence; the batch size is dim 1 of the time-major input.
        batch_size = input_variable.size(1)
        decoder_input = torch.full((1, batch_size), vocabulary.SOS_token,
                                   dtype=torch.long, device=self.device)
        # Initial decoder state: the first decoder.n_layers slices of the encoder's final hidden state.
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        return decoder, decoder_input, decoder_hidden, encoder_outputs, target_variable, mask

    def _train_backward(self, loss, train_model):
        """Back-propagate *loss*, clip both networks' gradient norms to self.clip, and step the optimizers."""
        encoder = train_model["encoder"]
        decoder = train_model["decoder"]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), self.clip)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), self.clip)
        train_model["encoder_optimizer"].step()
        train_model["decoder_optimizer"].step()

    def _train(self, input_variable, lengths, target_variable, mask, max_target_len, train_model):
        """Run one optimization step on a batch; returns the per-token average loss."""
        decoder, decoder_input, decoder_hidden, encoder_outputs, target_variable, mask = self._train_init(
            input_variable, lengths, target_variable, mask, train_model)
        loss, print_losses, n_totals = self._train_step(
            decoder, decoder_input, decoder_hidden, encoder_outputs, target_variable, mask, max_target_len)
        self._train_backward(loss, train_model)
        return sum(print_losses) / n_totals

    def _save_model_dict(self, train_model, iteration, voc, loss):
        """Write the checkpoint, overwriting the single file returned by _get_checkpoint_file_full_name."""
        model_dict = self._get_model_dict(train_model, iteration, voc, loss)
        torch.save(model_dict, self._get_checkpoint_file_full_name())

    def _show_batches(self, batches):
        """Debug helper: print every tensor of one training batch."""
        input_variable, lengths, target_variable, mask, max_target_len = batches
        print("input_variable:", input_variable)
        print("lengths:", lengths)
        print("target_variable:", target_variable)
        print("mask:", mask)
        print("max_target_len:", max_target_len)

    def _show_train_state(self, print_loss, iteration):
        """Print the average loss over the last print_every steps; returns 0 to reset the accumulator."""
        print_loss_avg = print_loss / self.print_every
        print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(
            iteration, iteration / self.n_iteration * 100, print_loss_avg))
        return 0

    def _get_model_dict(self, train_model, iteration, voc, loss):
        """Collect everything needed to resume training into one dict."""
        return {
            "en": train_model["encoder"].state_dict(),
            "de": train_model["decoder"].state_dict(),
            "en_opt": train_model["encoder_optimizer"].state_dict(),
            "de_opt": train_model["decoder_optimizer"].state_dict(),
            "embedding": train_model["embedding"].state_dict(),
            "iteration": iteration,
            "loss": loss,
            "voc_dict": voc.__dict__,
        }

    def _load_checkpoint(self, train_model, voc, checkpoint):
        """Restore model/optimizer state, the vocabulary and the iteration count from *checkpoint*."""
        train_model["encoder"].load_state_dict(checkpoint['en'])
        train_model["decoder"].load_state_dict(checkpoint['de'])
        train_model["encoder_optimizer"].load_state_dict(checkpoint['en_opt'])
        train_model["decoder_optimizer"].load_state_dict(checkpoint['de_opt'])
        train_model["embedding"].load_state_dict(checkpoint['embedding'])
        voc.__dict__ = checkpoint['voc_dict']
        train_model["iteration"] = checkpoint["iteration"]
        return train_model

    def _train_load_checkpoint(self, train_model, voc):
        """Resume from the checkpoint file when it exists and train_load_checkpoint_file is set."""
        loadFilename = self._get_checkpoint_file_full_name()
        if os.path.exists(loadFilename) and self.train_load_checkpoint_file:
            # map_location lets a checkpoint saved on another device be resumed here.
            checkpoint = torch.load(loadFilename, map_location=self.device)
            train_model = self._load_checkpoint(train_model, voc, checkpoint)
        return train_model

    def _test_load_checkpoint(self, train_model, voc):
        """Like _train_load_checkpoint, but maps every tensor to the CPU first.

        This lets a GPU-trained checkpoint load on a CPU-only machine. The
        original loaded the file twice, the first time without map_location,
        which failed on such machines before the CPU-mapped load was reached.
        """
        loadFilename = self._get_checkpoint_file_full_name()
        if os.path.exists(loadFilename) and self.train_load_checkpoint_file:
            checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
            train_model = self._load_checkpoint(train_model, voc, checkpoint)
        return train_model

    def _get_save_directory(self):
        """Return (creating it if needed) save_dir/model_name/corpus_name/<enc>-<dec>_<hidden>."""
        directory = os.path.join(self.save_dir,
                                 self.model_name,
                                 self.corpus_name,
                                 '{}-{}_{}'.format(self.encoder_n_layers,
                                                   self.decoder_n_layers,
                                                   self.hidden_size))
        if not os.path.exists(directory):
            os.makedirs(directory)
        return directory

    def _get_checkpoint_file_full_name(self):
        """Return the path of the single checkpoint file inside the save directory."""
        return "%s/%s" % (self._get_save_directory(), "checkpoint.tar")

    def create_train_model(self, voc, status="train"):
        """Build the encoder, decoder and Adam optimizers, then load a checkpoint if one applies.

        Args:
            voc: vocabulary; voc.num_words sizes the embedding and the output layer.
            status: "train" uses the training loader; any other value uses
                the CPU-mapping test loader.
        """
        embedding = nn.Embedding(voc.num_words, self.hidden_size)
        encoder = model.EncoderRNN(self.hidden_size, embedding, self.encoder_n_layers, self.dropout)
        encoder = encoder.to(self.device)
        decoder = model.LuongAttnDecoderRNN(self.attn_model, embedding, self.hidden_size, voc.num_words, self.decoder_n_layers, self.dropout)
        decoder = decoder.to(self.device)
        # Ensure dropout layers are in train mode.
        encoder.train()
        decoder.train()
        encoder_optimizer = optim.Adam(encoder.parameters(), lr=self.learning_rate)
        decoder_optimizer = optim.Adam(decoder.parameters(), lr=self.learning_rate * self.decoder_learning_ratio)
        train_model = {
            "encoder": encoder,
            "decoder": decoder,
            "encoder_optimizer": encoder_optimizer,
            "decoder_optimizer": decoder_optimizer,
            "embedding": embedding,
            "iteration": 0,
        }
        if status == "train":
            return self._train_load_checkpoint(train_model, voc)
        return self._test_load_checkpoint(train_model, voc)

    def trainIters(self, voc, pairs, train_model):
        """Train for self.n_iteration batches, continuing the global step count.

        All batches are built up front. The global step of local iteration
        ``i`` is ``train_model["iteration"] + i``, so a run resumed from a
        checkpoint saved at step N continues at N + 1. The original added an
        extra 1 and skipped a step number.
        """
        training_batches = _get_training_batches(voc, pairs, self.batch_size, self.n_iteration)
        print_loss = 0
        completed_before = train_model["iteration"]
        for iteration in range(1, self.n_iteration + 1):
            input_variable, lengths, target_variable, mask, max_target_len = training_batches[iteration - 1]
            loss = self._train(input_variable, lengths, target_variable, mask, max_target_len, train_model)
            print_loss += loss
            cur_iteration = completed_before + iteration
            if iteration % self.print_every == 0:
                print_loss = self._show_train_state(print_loss, cur_iteration)
            if iteration % self.save_every == 0:
                self._save_model_dict(train_model, cur_iteration, voc, loss)

    def evaluate_input(self, voc, train_model):
        """Interactive loop: read a line, normalize it and print the bot's reply.

        Enter 'q' or 'quit' (or send EOF) to stop. A sentence with an unknown
        word prints an error message instead of ending the loop.
        """
        while True:
            try:
                input_sentence = input('> ')
            except EOFError:
                break
            if input_sentence in ('q', 'quit'):
                break
            try:
                input_sentence = corpus_dataset.normalize_string(input_sentence)
                output_words = self._evaluate(voc, input_sentence, train_model)
            except KeyError:
                print("Error: Encountered unknown word.")
                continue
            output_words = [x for x in output_words if x not in ('EOS', 'PAD')]
            print('Bot:', ' '.join(output_words))
model.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : model.py
# Create date : 2019-01-16 11:38
# Modified date : 2019-02-02 14:50
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import vocabulary
class EncoderRNN(nn.Module):
    """Bidirectional GRU encoder over embedded token indices.

    Args:
        hidden_size: GRU input and per-direction hidden size; the embedding
            output dimension must match it.
        embedding: embedding module applied to the input indices.
        n_layers: number of GRU layers.
        dropout: inter-layer GRU dropout, only used when n_layers > 1.
    """

    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        self.gru = nn.GRU(hidden_size,
                          hidden_size,
                          n_layers,
                          dropout=(0 if n_layers == 1 else dropout),
                          bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        """Encode a padded, time-major batch.

        Args:
            input_seq: token indices, (max_len, batch) (pack_padded_sequence
                is used with its default batch_first=False).
            input_lengths: unpadded length of each sequence, in decreasing order.
            hidden: optional initial GRU state.

        Returns:
            outputs: (max_len, batch, hidden_size), with the forward and
                backward directions summed.
            hidden: final GRU state, (n_layers * 2, batch, hidden_size).
        """
        embedded = self.embedding(input_seq)
        # pack_padded_sequence requires CPU lengths (PyTorch >= 1.7); callers may pass them on CUDA.
        if torch.is_tensor(input_lengths):
            input_lengths = input_lengths.cpu()
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum both directions so downstream layers see hidden_size features.
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden
class Attn(torch.nn.Module):
    """Luong attention weights using the 'dot', 'general' or 'concat' score.

    Raises ValueError for any other *method*.
    """

    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
            # torch.FloatTensor(n) returns uninitialized memory (possibly NaN/inf),
            # so v is initialized explicitly.
            self.v = torch.nn.Parameter(torch.empty(hidden_size).uniform_(-0.1, 0.1))

    def dot_score(self, hidden, encoder_output):
        """Score = <hidden, encoder_output> summed over the feature axis."""
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        """Score = <hidden, W encoder_output>."""
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        """Score = <v, tanh(W [hidden; encoder_output])>, with hidden broadcast over time."""
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        """Return softmax attention weights over the encoder time steps.

        Args:
            hidden: decoder GRU output for one step, assumed (1, batch, hidden_size).
            encoder_outputs: (max_len, batch, hidden_size).

        Returns:
            (batch, 1, max_len) weights that sum to 1 along the last axis.
        """
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        else:  # 'dot' -- __init__ rejects any other method
            attn_energies = self.dot_score(hidden, encoder_outputs)
        # (max_len, batch) -> (batch, max_len)
        attn_energies = attn_energies.t()
        return F.softmax(attn_energies, dim=1).unsqueeze(1)
class LuongAttnDecoderRNN(nn.Module):
    """One-step GRU decoder with Luong attention over the encoder outputs.

    The sub-modules are registered in this order: embedding, dropout, gru,
    concat, out, attn. The order fixes the parameter order seen by the
    optimizer.
    """

    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()
        # Plain configuration attributes.
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout
        # Layers.
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        gru_dropout = 0 if n_layers == 1 else dropout
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=gru_dropout)
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        """Decode a single time step.

        Args:
            input_step: token indices for this step, (1, batch) as built by the callers.
            last_hidden: previous GRU hidden state.
            encoder_outputs: (max_len, batch, hidden_size) from the encoder.

        Returns:
            (output, hidden): output is (batch, output_size) probabilities
            (softmax already applied); hidden is the new GRU state.
        """
        step_embedding = self.embedding_dropout(self.embedding(input_step))
        gru_out, hidden = self.gru(step_embedding, last_hidden)
        weights = self.attn(gru_out, encoder_outputs)
        # (batch, 1, max_len) x (batch, max_len, hidden) -> (batch, 1, hidden) -> (batch, hidden)
        context = weights.bmm(encoder_outputs.transpose(0, 1)).squeeze(1)
        combined = torch.cat((gru_out.squeeze(0), context), 1)
        concat_output = torch.tanh(self.concat(combined))
        probabilities = F.softmax(self.out(concat_output), dim=1)
        return probabilities, hidden
class GreedySearchDecoder(nn.Module):
    """Greedy decoding: at each step the arg-max token is fed back as the next input."""

    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length, device):
        """Decode exactly *max_length* tokens for a single input sequence.

        Decoding does not stop early at EOS; callers filter EOS/PAD afterwards.

        Returns:
            (tokens, scores): 1-D LongTensor of chosen token indices and 1-D
            tensor of their decoder probabilities.
        """
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        hidden = encoder_hidden[:self.decoder.n_layers]
        step_input = torch.ones(1, 1, device=device, dtype=torch.long) * vocabulary.SOS_token
        collected_tokens = torch.zeros([0], device=device, dtype=torch.long)
        collected_scores = torch.zeros([0], device=device)
        for _ in range(max_length):
            step_output, hidden = self.decoder(step_input, hidden, encoder_outputs)
            best_score, best_token = torch.max(step_output, dim=1)
            collected_tokens = torch.cat((collected_tokens, best_token), dim=0)
            collected_scores = torch.cat((collected_scores, best_score), dim=0)
            # (1,) -> (1, 1): next single-step decoder input
            step_input = torch.unsqueeze(best_token, 0)
        return collected_tokens, collected_scores
etc.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : etc.py
# Create date : 2019-01-17 22:50
# Modified date : 2019-02-02 14:10
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
import torch
config = {}

# --- Corpus location and file layout ---
config["corpus_name"] = "cornell movie-dialogs corpus"
config["corpus_path"] = "./data/%s" % config["corpus_name"]
config["delimiter"] = '\t'
config["formatted_file_name"] = "formatted_movie_lines.txt"
config["conversation_file_name"] = "movie_conversations.txt"
config["lines_file_name"] = "movie_lines.txt"
config["movie_lines_fields"] = ["lineID", "characterID", "movieID", "character", "text"]
config["movie_conversations_fields"] = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

# --- Model architecture ---
config["model_name"] = 'cb_model'
# Attention scoring method; model.Attn also accepts 'general' and 'concat'.
config["attn_model"] = 'dot'
config["hidden_size"] = 500
config["encoder_n_layers"] = 2
config["decoder_n_layers"] = 2
config["dropout"] = 0.1

# --- Training schedule and optimisation ---
config["print_every"] = 20
config["save_every"] = 500
config["n_iteration"] = 1000
config["clip"] = 50.0
config["learning_rate"] = 0.0001
config["decoder_learning_ratio"] = 5.0  # decoder lr = learning_rate * this ratio
config["batch_size"] = 64
config["save_dir"] = "./data/save"
# Stored on CorpusGraph, but the training code never reads it.
config["checkpoint_iter"] = 4000
config["min_count"] = 3  # Minimum word count threshold for trimming
config["max_length"] = 10
# Probability of teacher forcing at each decoder step (1.0 = always).
config["teacher_forcing_ratio"] = 1.0
config["train_load_checkpoint_file"] = True

# --- Device selection ---
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
config["device"] = device
main.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : main.py
# Create date : 2019-02-02 13:44
# Modified date : 2019-02-02 13:45
# Author : DARREN
# Describe : not set
# Email : [email protected]
#####################################
from __future__ import division
from __future__ import print_function
from chatbot_train import run_train
from chatbot_test import run_test
def run():
    """Train the chatbot, then open the interactive chat loop on the saved checkpoint."""
    run_train()
    run_test()


# Guarded so that importing main does not start a full training run.
if __name__ == "__main__":
    run()
github:
https://github.com/darr/chatbot