config.py
import pickle
import torch
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
"""word2sequence"""
chatbot_train_batch_size = 200
chatbot_test_batch_size = 300
# vocabularies built by gen_ws.py (run that script first so these pickles exist)
input_ws = pickle.load(open('./model/ws_input.pkl', 'rb'))
target_ws = pickle.load(open('./model/ws_target.pkl', 'rb'))
chatbot_input_max_len = 20
chatbot_target_max_len = 30
"""Encoder"""
chatbot_encoder_embedding_dim = 300
chatbot_encoder_hidden_size = 128
chatbot_encoder_numlayers = 2
chatbot_encoder_bidirectional = True
# to use dropout inside an RNN, num_layers must be >= 2 (PyTorch warns otherwise)
chatbot_encoder_dropout = 0.3
"""Decoder"""
chatbot_decoder_embedding_dim = 300
chatbot_decoder_numlayers = 1
teacher_forcing = 0.5
"""beam search"""
beam_width = 2
The dataset combines the Xiaohuangji (小黄鸡) corpus with Weibo and other corpora. After basic cleaning, every sentence is split character by character, yielding input.txt and target.txt.
Sample lines for reference:
input.txt
呵 呵
不 是
怎 么 了
开 心 点 哈 , 一 切 都 会 好 起 来
我 还 喜 欢 她 , 怎 么 办
短 信
你 知 道 谁 么
许 兵 是 谁
这 么 假
许 兵 是 傻 逼
许 兵 是 谁
许 兵 是 谁
许 兵 是 谁
许 兵 到 底 是 谁
尼 玛 , 许 兵 到 底 是 谁
小 黄 鸭 , 你 有 女 朋 友 么
那 你 有 男 朋 友 么
那 你 在 哪
target.txt
是 王 若 猫 的 。
那 是 什 么 。
我 很 难 过 , 安 慰 我 。
嗯 , 会 的
我 帮 你 告 诉 她 , 发 短 信 还 是 打 电 话 。
嗯 嗯 , 我 也 相 信
肯 定 不 是 我 , 是 阮 德 培
吴 院 四 班 小 帅 哥
三 鹿 奶 粉 也 假 , 不 一 样 的 卖 啊
被 你 发 现 了 。
是 我 善 良 可 爱 的 主 人 的 老 公 啊
是 穆 森 的 老 婆 啊
奇 葩
不 要 凶 我 , 应 该 大 概 也 许 是 叶 祺 吧 。
我 , 我 不 会 告 诉 你 我 是 小 澜 的 。
老 娘 是 女 的 。
没 有 呢 , 我 只 要 主 人 一 个 人 疼 爱 我 嘛 。
我 无 聊
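The character-level split above can be reproduced with a short preprocessing sketch; the raw file name and the two-lines-per-pair layout below are assumptions, so adapt them to the actual corpus:

# preprocess.py -- sketch: turn raw question/answer pairs into the
# space-separated, character-level input.txt / target.txt format.
# Assumes a hypothetical raw_pairs.txt where each pair occupies two
# consecutive lines; the real corpora need their own cleaning first.
def char_split(line):
    # '怎么了' -> '怎 么 了'
    return ' '.join(line.strip())

with open('./corpus/raw_pairs.txt', encoding = 'utf-8') as f:
    lines = [l.strip() for l in f if l.strip()]
with open('./corpus/input.txt', 'w', encoding = 'utf-8') as f_in, \
     open('./corpus/target.txt', 'w', encoding = 'utf-8') as f_out:
    for question, answer in zip(lines[0::2], lines[1::2]):
        f_in.write(char_split(question) + '\n')
        f_out.write(char_split(answer) + '\n')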
Tokenization implementation: cut_sentence.py
import string
import jieba
import jieba.posseg as psg
import logging
# paths to the stop-word list and custom dictionary
stopwords_path = './corpus/stopwords.txt'
keywords_path = './corpus/keywords.txt'
# lowercase English letters
letters = string.ascii_lowercase
# silence jieba's logging
jieba.setLogLevel(logging.INFO)
# read all stop words into a list (note: readlines(), not readline())
stop_words = [i.strip() for i in open(stopwords_path, encoding = 'utf-8').readlines()]
def cut(sentence, by_word = False, use_stopwords = False, use_seg = False):
"""
分词方法
:param sentence: 待分词的句子
:param by_word: 按字切分
:param use_stopwords: 使用停用词
:param use_seg: 返回词性
:return:
"""
if by_word:
return cut_sentence_by_word(sentence)
else:
return cut_sentence(sentence, use_stopwords, use_seg)
def cut_sentence(sentence, use_stopwords, use_seg):
    if use_seg:
        # psg.lcut returns [(i.word, i.flag), ...]
        result = psg.lcut(sentence)
        if use_stopwords:
            result = [i for i in result if i[0] not in stop_words]
    else:
        result = jieba.lcut(sentence)
        if use_stopwords:
            result = [i for i in result if i not in stop_words]
    return result
def cut_sentence_by_word(sentence):
"""
按字进行切分
:param sentence: 待分词的语句
:return:
"""
    temp = ''    # buffer for consecutive English letters
    result = []  # collected tokens
    # iterate character by character
for word in sentence:
        # check whether the character is an English letter
if word in letters:
temp += word
        # on a non-letter character there are two cases:
        # 1. temp == '': the current char is a CJK character; append it to result directly
        # 2. temp != '': an English word just finished; flush temp into result and reset it
else:
if len(temp) > 0:
result.append(temp)
temp = ''
result.append(word)
    # the last English word may still be buffered in temp after the loop
if len(temp) > 0:
result.append(temp)
return result
if __name__ == '__main__':
sentence = '今天天气好热a'
res1 = cut(sentence, by_word = True)
res2 = cut(sentence, by_word = False, use_seg = True)
res3 = cut(sentence, by_word = False, use_seg = False)
print(res1)
print(res2)
print(res3)
word2sequence.py
class Word2Sequence(object):
    # the four special tokens; the tag strings were lost in the original
    # listing, so any four distinct placeholders work (assumed here)
    PAD_TAG = '<PAD>'
    UNK_TAG = '<UNK>'
    SOS_TAG = '<SOS>'
    EOS_TAG = '<EOS>'
    PAD = 0
    UNK = 1
    SOS = 2
    EOS = 3
    def __init__(self):
        # token -> index mapping
        self.dict = {
            self.PAD_TAG: self.PAD,
            self.UNK_TAG: self.UNK,
            self.SOS_TAG: self.SOS,
            self.EOS_TAG: self.EOS
        }
        # word-frequency counter
        self.count = {}
def fit(self, sentence):
"""
统计每句话中的词频
:param sentence: 经过分词后的句子
:return:
"""
for word in sentence:
self.count[word] = self.count.get(word, 0) + 1
    def build_vocab(self, min = 5, max = None, max_features = None):
        """
        Build the vocabulary
        :param min: minimum word frequency
        :param max: maximum word frequency
        :param max_features: maximum vocabulary size
        :return:
        """
        # note: both frequency bounds are inclusive
        if min is not None:
            self.count = {k: v for k, v in self.count.items() if v >= min}
        if max is not None:
            self.count = {k: v for k, v in self.count.items() if v <= max}
        if max_features is not None:
            # [(k, v), (k, v), ...] -> {k: v, k: v}
            self.count = dict(sorted(self.count.items(), key = lambda x: x[1], reverse = True)[: max_features])
        # assign an index to every remaining word
        for word in self.count:
            self.dict[word] = len(self.dict)
self.inverse_dict = dict(zip(self.dict.values(), self.dict.keys()))
    def transfrom(self, sentence, max_len = None, add_eos = False):
        """
        Convert text to an index sequence
        :param sentence: tokenized sentence
        :param max_len: maximum sequence length
        :param add_eos: whether to append an end-of-sequence tag
        :return:
        """
        # reserve one slot for EOS
        if max_len and add_eos:
            max_len = max_len - 1
        # truncate sentences that are too long
        if max_len <= len(sentence):
            sentence = sentence[:max_len]
        # pad sentences that are too short
        if max_len > len(sentence):
            sentence += [self.PAD_TAG] * (max_len - len(sentence))
        if add_eos:
            # insert EOS just before the first PAD tag
            if self.PAD_TAG in sentence:
                index = sentence.index(self.PAD_TAG)
                sentence.insert(index, self.EOS_TAG)
            # no padding: simply append EOS
            else:
                sentence += [self.EOS_TAG]
return [self.dict.get(i, self.UNK) for i in sentence]
    def inverse_transform(self, indices):
        """
        Convert an index sequence back to text
        :param indices: index sequence
        :return:
        """
        result = []
        for i in indices:
            # map the index back to a token; unknown indices become UNK_TAG
            temp = self.inverse_dict.get(i, self.UNK_TAG)
            # stop at the end-of-sequence marker
            # (compare against self.EOS, the index, not the tag string)
            if i != self.EOS:
                result.append(temp)
            else:
                break
        # join the tokens into a single string
return ''.join(result)
def __len__(self):
return len(self.dict)
if __name__ == '__main__':
sentences = [["今天","天气","很","好"],
["今天","去","吃","什么"]]
ws = Word2Sequence()
for sentence in sentences:
ws.fit(sentence)
ws.build_vocab(min = 1)
print('vocab_dict', ws.dict)
ret = ws.transfrom(["好","好","好","好","好","好","好","热","呀"],max_len=13, add_eos=True)
print('transfrom',ret)
ret = ws.inverse_transform(ret)
print('inverse',ret)
print(ws.PAD_TAG)
print(ws.PAD)
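For reference, the demo above should print roughly the following (a sketch: word indices follow insertion order, and the <PAD>/<UNK>/<SOS>/<EOS> placeholders are the ones assumed in the class definition):

# vocab_dict {'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3,
#             '今天': 4, '天气': 5, '很': 6, '好': 7, '去': 8, '吃': 9, '什么': 10}
# transfrom [7, 7, 7, 7, 7, 7, 7, 1, 1, 3, 0, 0, 0]   # 热/呀 -> UNK; EOS inserted before the PADs
# inverse 好好好好好好好<UNK><UNK>                      # inverse_transform stops at the EOS index
# <PAD>
# 0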
gen_ws.py
import pickle
from tqdm import tqdm
import random
import config
from word2sequence import Word2Sequence
def chatbot_data_split():
    # split the corpus into training and test sets
    input_lines = open('./corpus/input.txt', encoding = 'utf-8').readlines()
    target_lines = open('./corpus/target.txt', encoding = 'utf-8').readlines()
    # training set
    f_train_input = open('./corpus/train_input.txt', 'a', encoding = 'utf-8')
    f_train_target = open('./corpus/train_target.txt', 'a', encoding = 'utf-8')
    # test set
    f_test_input = open('./corpus/test_input.txt', 'a', encoding = 'utf-8')
    f_test_target = open('./corpus/test_target.txt', 'a', encoding = 'utf-8')
    # take one line from each file at a time and write it to the training or
    # test set with an 80:20 split (note the 'a' mode: delete old files before re-running)
    for input_, target in tqdm(zip(input_lines, target_lines), desc = 'splitting'):
if random.random() > 0.2:
f_train_input.write(input_)
f_train_target.write(target)
else:
f_test_input.write(input_)
f_test_target.write(target)
f_train_input.close()
f_train_target.close()
f_test_input.close()
f_test_target.close()
def gen_ws(train_path, test_path, save_path):
"""
生成词表
:param train_path: 训练数据
:param test_path: 测试数据
:param save_path: 保存路径
:return:
"""
ws = Word2Sequence()
for line in tqdm(open(train_path, encoding = 'utf-8').readlines(), desc = 'build_vocab1..'):
ws.fit(line.strip().split())
for line in tqdm(open(test_path, encoding= 'utf-8').readlines(), desc = 'build_vocab2..'):
ws.fit(line.strip().split())
ws.build_vocab(min = 5, max = None, max_features = 5000)
print(len(ws))
pickle.dump(ws, open(save_path, 'wb'))
if __name__ == '__main__':
chatbot_data_split()
train_input_path = './corpus/train_input.txt'
test_input_path = './corpus/test_input.txt'
train_target_path = './corpus/train_target.txt'
test_target_path = './corpus/test_target.txt'
input_ws_path = './model/ws_input.pkl'
target_ws_path = './model/ws_target.pkl'
gen_ws(train_input_path, test_input_path, input_ws_path)
gen_ws(train_target_path, test_target_path, target_ws_path)
dataset.py
import random
from tqdm import tqdm
import config
import torch
from torch.utils.data import Dataset, DataLoader
class ChatbotDataset(Dataset):
def __init__(self, train = True):
"""
:param train: 指定生成训练数据还是测试数据
"""
input_path = './corpus/train_input.txt' if train else './corpus/test_input.txt'
target_path = './corpus/train_target.txt' if train else './corpus/test_target.txt'
self.input_data = open(input_path, encoding = 'utf-8').readlines()
self.target_data = open(target_path, encoding = 'utf-8').readlines()
# 由于闲聊模型,因此输入必须对应一条输出
assert len(self.input_data) == len(self.target_data), '输入输出长度不一致!'
def __getitem__(self, idx):
input = self.input_data[idx].strip().split()
target = self.target_data[idx].strip().split()
        # true (unpadded) lengths, capped at the configured maximums
input_length = len(input) if len(input) < config.chatbot_input_max_len else config.chatbot_input_max_len
target_length = len(target) if len(target) < config.chatbot_target_max_len else config.chatbot_target_max_len
        # convert the tokens to index sequences
input = config.input_ws.transfrom(input, max_len = config.chatbot_input_max_len)
target = config.target_ws.transfrom(target, max_len = config.chatbot_target_max_len)
return input, target, input_length, target_length
def __len__(self):
return len(self.input_data)
def get_dataloader(train = True):
    # build the train or test DataLoader
batch_size = config.chatbot_train_batch_size if train else config.chatbot_test_batch_size
return DataLoader(ChatbotDataset(train), batch_size = batch_size, shuffle = True, collate_fn = collate_fn)
def collate_fn(batch):
    # sort the batch by input length (descending), as required by pack_padded_sequence
batch = sorted(batch, key = lambda x : x[2], reverse = True)
input, target, input_length, target_length = zip(*batch)
    # wrap everything as tensors
input_tensor = torch.LongTensor(input)
target_tensor = torch.LongTensor(target)
input_length = torch.LongTensor(input_length)
target_length = torch.LongTensor(target_length)
return input_tensor, target_tensor, input_length, target_length
if __name__ == '__main__':
train_data_loader = get_dataloader(train = False)
for idx, (input, target, input_length, target_length) in enumerate(train_data_loader):
print(input)
print(input.size()) # [batch_size, seq_len]
print(target)
print(target.size()) # [batch_size, seq_len]
print(input_length)
print(input_length.size()) # [batch_size]
print(target_length)
print(target_length.size()) # [batch_size]
break
print(config.target_ws.dict)
print(len(config.input_ws))
print(len(config.target_ws))
encoder.py
import torch.nn as nn
import config
import torch
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
self.vocab_size = len(config.input_ws)
self.embedding = nn.Embedding(
num_embeddings = self.vocab_size,
embedding_dim = config.chatbot_encoder_embedding_dim
)
self.gru = nn.GRU(
input_size = config.chatbot_encoder_embedding_dim,
hidden_size = config.chatbot_encoder_hidden_size,
num_layers = config.chatbot_encoder_numlayers,
bidirectional = config.chatbot_encoder_bidirectional,
dropout = config.chatbot_encoder_dropout,
batch_first = True
)
    def forward(self, input, input_length):
        # embedding: input [batch_size, seq_len] -> [batch_size, seq_len, embedding_dim]
        input_embeded = self.embedding(input)
        # pack the padded batch; recent PyTorch versions require the lengths on the CPU
        input_packed = nn.utils.rnn.pack_padded_sequence(input = input_embeded, lengths = input_length.cpu(), batch_first = True)
        # GRU: output -> [batch_size, seq_len, hidden_size * 2]
        #      hidden -> [num_layers * 2, batch_size, hidden_size]
        output, hidden = self.gru(input_packed)
        # unpack: encoder_output -> [batch_size, seq_len, hidden_size * 2]
        encoder_output, output_length = nn.utils.rnn.pad_packed_sequence(sequence = output, batch_first = True, padding_value = config.input_ws.PAD)
        # the GRU is bidirectional, so concatenate the last forward and backward hidden states
        # hidden [num_layers * 2, batch_size, hidden_size]; hidden[-1] == hidden[-1, :, :]
        # torch.cat((hidden[-1, :, :], hidden[-2, :, :]), dim = -1) -> [batch_size, hidden_size * 2]
        # the decoder expects a 3-D hidden state, so unsqueeze to [1, batch_size, hidden_size * 2]
        encoder_hidden = torch.cat((hidden[-1, :, :], hidden[-2, :, :]), dim = -1).unsqueeze(0)
return encoder_output, encoder_hidden
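A quick standalone check of the bidirectional hidden-state concatenation (a sketch with arbitrary dimensions, independent of config):

# shape_check.py -- sketch: verify the [1, batch, hidden * 2] encoder_hidden shape.
import torch
import torch.nn as nn

gru = nn.GRU(input_size = 300, hidden_size = 128, num_layers = 2,
             bidirectional = True, batch_first = True)
x = torch.randn(4, 10, 300)              # [batch_size, seq_len, embedding_dim]
output, hidden = gru(x)
print(output.shape)                      # torch.Size([4, 10, 256]) = hidden_size * 2
print(hidden.shape)                      # torch.Size([4, 4, 128]) = [num_layers * 2, batch, hidden]
encoder_hidden = torch.cat((hidden[-1], hidden[-2]), dim = -1).unsqueeze(0)
print(encoder_hidden.shape)              # torch.Size([1, 4, 256])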
decoder.py
import torch.nn as nn
import torch
import config
import torch.nn.functional as F
import random
import numpy as np
class Decoder(nn.Module):
def __init__(self):
super(Decoder, self).__init__()
self.vocab_size = len(config.target_ws)
self.hidden_size = config.chatbot_encoder_hidden_size * 2
self.embedding = nn.Embedding(
num_embeddings = len(config.target_ws),
embedding_dim = config.chatbot_decoder_embedding_dim
)
self.gru = nn.GRU(
input_size = config.chatbot_decoder_embedding_dim,
            # the encoder is a bidirectional GRU whose final hidden states are
            # concatenated, so the decoder hidden size is encoder_hidden_size * 2;
            # hidden_size in the comments below refers to the decoder's hidden size
hidden_size = self.hidden_size,
num_layers = config.chatbot_decoder_numlayers,
batch_first = True,
bidirectional = False
)
self.fc = nn.Linear(self.hidden_size ,self.vocab_size)
    def forward(self, encoder_hidden, target):
        """
        :param encoder_hidden: [1, batch_size, hidden_size]
        :param target: [batch_size, seq_len]
        :return:
        """
        batch_size = encoder_hidden.size(1)
        # a [batch_size, 1] tensor of SOS tokens is the decoder's first time-step input
        decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)
        # encoder_hidden [1, batch_size, hidden_size] is the decoder's first hidden state
        decoder_hidden = encoder_hidden
        # [batch_size, seq_len, vocab_size] buffer that collects every time step's output
        decoder_outputs = torch.zeros([batch_size, config.chatbot_target_max_len, self.vocab_size]).to(config.device)
        # decide whether to use teacher forcing for this batch
        if random.random() > config.teacher_forcing:
            # iterate over the time steps
            for t in range(config.chatbot_target_max_len):
                output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
                # store the output of time step t
                decoder_outputs[:, t, :] = output_t
                # without teacher forcing, the prediction becomes the next input
                # index [batch_size, 1]
                """
                With max():
                    value, index = output_t.max(dim = -1)  # [batch_size]
                    decoder_input = index.unsqueeze(1)
                With argmax():
                    index = output_t.argmax(dim = -1)      # [batch_size]
                    decoder_input = index.unsqueeze(1)
                """
                value, index = torch.topk(output_t, k = 1)
                # decoder_input must keep the shape [batch_size, 1]
                decoder_input = index
        else:
            for t in range(config.chatbot_target_max_len):
                output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
                decoder_outputs[:, t, :] = output_t
                # with teacher forcing, the ground truth becomes the next input,
                # reshaped to [batch_size, 1]
                decoder_input = target[:, t].unsqueeze(1)
return decoder_outputs, decoder_hidden
def forward_step(self, decoder_input, decoder_hidden):
"""
处理每个时间步逻辑
:param decoder_input: [batch_size, 1]
:param decoder_hidden: [1, batch_size, hidden_size]
:return:
"""
# [batch_size, 1] ->[batch_size, 1, embedding_dim]
decoder_input_embeded = self.embedding(decoder_input)
# decoder_output_t [batch_size, 1, embedding_dim] ->[batch_size, 1, hidden_size]
# decoder_hidden_t [1,batch_size, hidden_size]
decoder_output_t, decoder_hidden_t = self.gru(decoder_input_embeded, decoder_hidden)
# 对decoder_output_t进行fcq前,需要对其进行形状改变 [batch_size, hidden_size]
decoder_output_t = decoder_output_t.squeeze(1)
# 进行fc -> [batch_size, vocab_size]
decoder_output_t = F.log_softmax(self.fc(decoder_output_t), dim = -1)
return decoder_output_t, decoder_hidden_t
def evaluate(self, encoder_hidden):
"""
评估, 和fowward逻辑类似
:param encoder_hidden: encoder最后time step的隐藏状态 [1, batch_size, hidden_size]
:return:
"""
batch_size = encoder_hidden.size(1)
# 初始化一个[batch_size, 1]的SOS张量,作为第一个time step的输出
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)
# encoder_hidden 作为decoder第一个时间步的hidden [1, batch_size, hidden_size]
decoder_hidden = encoder_hidden
# 初始化[batch_size, seq_len, vocab_size]的outputs 拼接每个time step结果
decoder_outputs = torch.zeros((batch_size, config.chatbot_target_max_len, self.vocab_size)).to(config.device)
# 初始化一个空列表,存储每次的预测序列
predict_result = []
# 对每个时间步进行更新
for t in range(config.chatbot_target_max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
# 拼接每个time step,decoder_output_t [batch_size, vocab_size]
decoder_outputs[:, t, :] = decoder_output_t
# 由于是评估,需要每次都获取预测值
index = torch.argmax(decoder_output_t, dim = -1)
# 更新下一时间步的输入
decoder_input = index.unsqueeze(1)
# 存储每个时间步的预测序列
predict_result.append(index.cpu().detach().numpy()) # [[batch], [batch]...] ->[seq_len, vocab_size]
# 结果转换为ndarry,每行是一个预测结果即单个字对应的索引, 所有行为seq_len长度
predict_result = np.array(predict_result).transpose() # (batch_size, seq_len)的array
return decoder_outputs, predict_result
Seq2Seq.py
import torch.nn as nn
from encoder import Encoder
from decoder import Decoder
class Seq2SeqModel(nn.Module):
def __init__(self):
super(Seq2SeqModel, self).__init__()
self.encoder = Encoder()
self.decoder = Decoder()
def forward(self, input, target, input_length, target_length):
encoder_outputs, encoder_hidden = self.encoder(input, input_length)
decoder_outputs, decoder_hidden = self.decoder(encoder_hidden, target)
return decoder_outputs
def evaluation(self, input, input_length):
encoder_outputs, encoder_hidden = self.encoder(input, input_length)
decoder_outputs, predict_result = self.decoder.evaluate(encoder_hidden)
return decoder_outputs, predict_result
train.py
import torch
from Seq2Seq import Seq2SeqModel
import torch.optim as optim
import config
from dataset import get_dataloader
from tqdm import tqdm
import torch.nn.functional as F
import numpy as np
import os
model = Seq2SeqModel().to(config.device)
if os.path.exists('./model/chatbot_model.pkl'):
model.load_state_dict(torch.load('./model/chatbot_model.pkl'))
optimizer = optim.Adam(model.parameters())
loss_list = []
def train(epoch):
train_dataloader = get_dataloader(train = True)
bar = tqdm(train_dataloader, desc = 'training', total = len(train_dataloader))
model.train()
for idx, (input, target, input_length, target_length) in enumerate(bar):
input = input.to(config.device)
target = target.to(config.device)
input_length = input_length.to(config.device)
target_length = target_length.to(config.device)
optimizer.zero_grad()
decoder_outputs = model(input, target, input_length, target_length)
        # the decoder applies log_softmax, so the loss must be F.nll_loss
# decoder_outputs [batch_size, seq_len, vocab_size]
# target [batch_size, seq_len]
loss = F.nll_loss(decoder_outputs.view(-1, len(config.target_ws)), target.view(-1), ignore_index = config.target_ws.PAD)
loss.backward()
optimizer.step()
loss_list.append(loss.item())
        bar.set_description('epoch:{}, idx:{}/{}, loss:{:.6f}'.format(epoch + 1, idx, len(train_dataloader), np.mean(loss_list)))
if idx % 100 == 0:
torch.save(model.state_dict(), './model/chatbot_model.pkl')
if __name__ == '__main__':
for i in range(100):
train(i)
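As the comment in the loop notes, log_softmax followed by F.nll_loss is equivalent to F.cross_entropy on the raw logits; a quick check (arbitrary shapes, a sketch):

import torch
import torch.nn.functional as F

logits = torch.randn(8, 5000)                 # [batch_size, vocab_size]
target = torch.randint(0, 5000, (8,))         # [batch_size]
nll = F.nll_loss(F.log_softmax(logits, dim = -1), target)
ce = F.cross_entropy(logits, target)
print(torch.allclose(nll, ce))                # True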
eval.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataset import get_dataloader
import config
import numpy as np
from Seq2Seq import Seq2SeqModel
import os
from tqdm import tqdm
model = Seq2SeqModel().to(config.device)
if os.path.exists('./model/chatbot_model.pkl'):
model.load_state_dict(torch.load('./model/chatbot_model.pkl'))
def eval():
model.eval()
loss_list = []
test_data_loader = get_dataloader(train = False)
with torch.no_grad():
bar = tqdm(test_data_loader, desc = 'testing', total = len(test_data_loader))
for idx, (input, target, input_length, target_length) in enumerate(bar):
input = input.to(config.device)
target = target.to(config.device)
input_length = input_length.to(config.device)
target_length = target_length.to(config.device)
            # get the model's predictions
            decoder_outputs, predict_result = model.evaluation(input, input_length)
            # compute the loss
            loss = F.nll_loss(decoder_outputs.view(-1, len(config.target_ws)), target.view(-1), ignore_index = config.target_ws.PAD)
loss_list.append(loss.item())
            bar.set_description('idx:{}/{}, loss:{:.6f}'.format(idx, len(test_data_loader), np.mean(loss_list)))
if __name__ == '__main__':
eval()
Running the chatbot: interface.py
from cut_sentence import cut
import torch
import config
from Seq2Seq import Seq2SeqModel
import os
# simulate a chat: answer whatever the user types
def interface():
    # load the trained model
    model = Seq2SeqModel().to(config.device)
    assert os.path.exists('./model/chatbot_model.pkl'), 'Train the model first by running train.py!'
    model.load_state_dict(torch.load('./model/chatbot_model.pkl'))
    model.eval()
    while True:
        # tokenize the raw input string
        input_string = input('me>>:')
        if input_string == 'q':
            print('See you next time!')
            break
        input_cuted = cut(input_string, by_word = True)
        # convert to an index sequence and wrap as tensors
        input_tensor = torch.LongTensor([config.input_ws.transfrom(input_cuted, max_len = config.chatbot_input_max_len)]).to(config.device)
        input_length_tensor = torch.LongTensor([len(input_cuted)]).to(config.device)
        # run inference
        outputs, predict = model.evaluation(input_tensor, input_length_tensor)
        # convert the predicted indices back to text
result = config.target_ws.inverse_transform(predict[0])
print('chatbot>>:', result)
if __name__ == '__main__':
interface()
attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import config
class Attention(nn.Module):
    def __init__(self, method):
        """
        Attention mechanism
        :param method: one of three ways to compute attention weights: general, dot, concat
        """
        super(Attention, self).__init__()
        self.method = method
        self.hidden_size = config.chatbot_encoder_hidden_size
        assert self.method in ['dot', 'general', 'concat'], 'attention method error'
        if self.method == 'dot':
            # dot: multiply decoder_hidden and encoder_outputs directly
            pass
        elif self.method == 'general':
            # general: apply a linear map to decoder_hidden, then multiply with encoder_outputs
            self.Wa = nn.Linear(config.chatbot_encoder_hidden_size * 2, config.chatbot_encoder_hidden_size * 2,
                                bias = False)
        elif self.method == 'concat':
            self.Wa = nn.Linear(config.chatbot_encoder_hidden_size * 4, config.chatbot_encoder_hidden_size * 2,
                                bias = False)
            self.Va = nn.Linear(config.chatbot_encoder_hidden_size * 2, 1, bias = False)
def forward(self, decoder_hidden, encoder_outputs):
"""
进行三种运算得到attn_weights
:param decoder_hidden: decoder每个时间步的隐藏状态[1, batch_size, en_hidden_size * 2]
由于encoder中使用Bi-GRU,最后对双向hidden进行了拼接,因此de_hidden_size = en_hidden_size * 2
未拼接前 encoder_hidden [1, batch_size, en_hidden_size]
:param encoder_outputs:encoder最后的输出[batch_size, en_seq_len, en_hidden_size * 2]
:return:
"""
if self.method == 'dot':
return self.dot_score(decoder_hidden, encoder_outputs)
elif self.method == 'general':
return self.general_score(decoder_hidden, encoder_outputs)
elif self.method == 'concat':
return self.concat_score(decoder_hidden, encoder_outputs)
    def dot_score(self, decoder_hidden, encoder_outputs):
        """
        dot: plain batch matrix product of decoder_hidden and encoder_outputs
        :param decoder_hidden: [1, batch_size, en_hidden_size * 2]
        :param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
        :return:
        """
        # reshape decoder_hidden to [batch_size, en_hidden_size * 2, 1] for bmm
        # bmm result: [batch_size, en_seq_len, 1]; squeeze -> [batch_size, en_seq_len]
        # softmax over en_seq_len so the weights form a probability distribution
        # (the original used log_softmax, which yields negative "weights" and a
        # meaningless context vector when multiplied with encoder_outputs)
        return F.softmax(torch.bmm(encoder_outputs, decoder_hidden.permute(1, 2, 0)).squeeze(-1), dim = -1)
    def general_score(self, decoder_hidden, encoder_outputs):
        """
        general: linearly transform decoder_hidden, then multiply with encoder_outputs
        :param decoder_hidden: [1, batch_size, en_hidden_size * 2]
        :param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
        :return:
        """
        # the linear layer expects a 2-D input, so drop the first dimension:
        # [1, batch_size, en_hidden_size * 2] -> [batch_size, en_hidden_size * 2]
        decoder_hidden = decoder_hidden.squeeze(0)
        # reshape to [batch_size, en_hidden_size * 2, 1] as in dot_score,
        # hence the linear layer's output size is also en_hidden_size * 2
        decoder_hidden = self.Wa(decoder_hidden).unsqueeze(-1)
        # bmm: [batch_size, en_seq_len, 1] -> squeeze -> [batch_size, en_seq_len]
        # mind the operand order: bmm is not symmetric
        return F.softmax(torch.bmm(encoder_outputs, decoder_hidden).squeeze(-1), dim = -1)
    def concat_score(self, decoder_hidden, encoder_outputs):
        """
        concat: concatenate decoder_hidden with encoder_outputs, push the result
        through a linear layer and tanh, then project to one score per encoder position
        :param decoder_hidden: [1, batch_size, en_hidden_size * 2]
        :param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
        :return:
        """
        encoder_seq_len = encoder_outputs.size(1)
        batch_size = encoder_outputs.size(0)
        # repeat(n, 1, 1) repeats n times along dim 0 and once along the others
        # decoder_hidden [1, batch_size, en_hidden_size * 2] -> squeeze(0): [batch_size, en_hidden_size * 2]
        # -> repeat: [en_seq_len, batch_size, en_hidden_size * 2] -> transpose: [batch_size, en_seq_len, en_hidden_size * 2]
        decoder_hidden_repeated = decoder_hidden.squeeze(0).repeat(encoder_seq_len, 1, 1).transpose(1, 0)
        # concatenate decoder_hidden_repeated with encoder_outputs
        # cat: [batch_size, en_seq_len, en_hidden_size * 4]
        # view: [batch_size * en_seq_len, en_hidden_size * 4]
        # hence the first linear layer's input size is en_hidden_size * 4
        h_cated = torch.cat((decoder_hidden_repeated, encoder_outputs), dim = -1).view(batch_size * encoder_seq_len, -1)
        # linear -> tanh -> second linear, then reshape to [batch_size, en_seq_len]
        # h_cated -> Wa: [batch_size * en_seq_len, en_hidden_size * 4] -> [batch_size * en_seq_len, en_hidden_size * 2]
        # -> Va: [batch_size * en_seq_len, en_hidden_size * 2] -> [batch_size * en_seq_len, 1]
        # -> view: [batch_size * en_seq_len, 1] -> [batch_size, en_seq_len]
        attn_weight = self.Va(torch.tanh(self.Wa(h_cated))).view([batch_size, encoder_seq_len])
        # softmax so the weights form a probability distribution over encoder positions
        return F.softmax(attn_weight, dim = -1)
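A quick sanity check of the three scoring methods (a sketch; it assumes the vocab pickles exist so that importing config succeeds, and chatbot_encoder_hidden_size = 128 as configured above):

# attn_shape_check.py -- sketch: verify shapes and normalization of the weights.
import torch
from attention import Attention

batch_size, en_seq_len, hidden2 = 4, 10, 128 * 2
decoder_hidden = torch.randn(1, batch_size, hidden2)
encoder_outputs = torch.randn(batch_size, en_seq_len, hidden2)
for method in ['dot', 'general', 'concat']:
    attn_weight = Attention(method)(decoder_hidden, encoder_outputs)
    # each method yields [batch_size, en_seq_len] weights that sum to 1 per row
    print(method, attn_weight.shape, attn_weight.sum(dim = -1))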
Modified decoder: decoder_attn.py
import torch
import torch.nn as nn
import config
from attention import Attention
import random
import torch.nn.functional as F
import numpy as np
class Decoder_Attn(nn.Module):
def __init__(self):
super(Decoder_Attn, self).__init__()
self.vocab_size = len(config.target_ws)
        # the encoder's Bi-GRU hidden states are concatenated; keeping
        # hidden_size = en_hidden_size makes the attention shapes line up
self.hidden_size = config.chatbot_encoder_hidden_size
self.embedding = nn.Embedding(
num_embeddings = self.vocab_size,
embedding_dim = config.chatbot_decoder_embedding_dim
)
self.gru = nn.GRU(
input_size = config.chatbot_decoder_embedding_dim,
hidden_size = self.hidden_size * 2,
num_layers = config.chatbot_decoder_numlayers,
batch_first = True,
bidirectional = False
)
        # maps each decoder time-step output in forward_step to vocabulary scores
self.fc = nn.Linear(self.hidden_size * 2, self.vocab_size)
        # attention module
self.attn = Attention(method = 'general')
        # fc_attn maps the concatenated [decoder_output; context_vector] back to hidden_size * 2
self.fc_attn = nn.Linear(self.hidden_size * 4, self.hidden_size * 2)
def forward(self, encoder_hidden, target, encoder_outputs):
"""
:param encoder_hidden: [1, batch_size, en_hidden_size * 2]
:param target: [batch_size, en_seq_len]
添加attention的decoder中, 对encoder_outputs进行利用与decoder_hidden计算得到注意力表示
encoder_outputs为新增参数
:param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
:return:
"""
batch_size = encoder_hidden.size(1)
# 初始化一个[batch_size, 1]的SOS作为decoder第一个时间步的输入decoder_input
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)
# 初始化一个[batch_size, de_seq_len, vocab_size]的张量拼接每个time step结果
decoder_outputs = torch.zeros([batch_size, config.chatbot_target_max_len, self.vocab_size]).to(config.device)
# encoder_hidden 作为decoder的第一个时间步的hidden
decoder_hidden = encoder_hidden
# 按照teacher_forcing的更新策略进行每个时间步的更新
teacher_forcing = random.random() > 0.5
if teacher_forcing:
            # iterate over the time steps
            for t in range(config.chatbot_target_max_len):
                # decoder_output_t [batch_size, vocab_size]
                # decoder_hidden [1, batch_size, en_hidden_size * 2]
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                # store the output of time step t
                decoder_outputs[:, t, :] = decoder_output_t
                # with teacher forcing, the ground truth becomes the next input
                # target[:, t] [batch_size] -> decoder_input [batch_size, 1]
                decoder_input = target[:, t].unsqueeze(1)
        else:
            for t in range(config.chatbot_target_max_len):
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                decoder_outputs[:, t, :] = decoder_output_t
                # without teacher forcing, the prediction becomes the next input
                # index [batch_size] -> decoder_input [batch_size, 1]
                index = torch.argmax(decoder_output_t, dim = -1)
                decoder_input = index.unsqueeze(-1)
return decoder_outputs, decoder_hidden
def forward_step(self, decoder_input, decoder_hidden, encoder_outputs):
"""
每个时间步的处理
:param decoder_input: [batch_size, 1]
:param decoder_hidden: [1, batch_size, en_hidden_size * 2]
:param encoder_outputs [batch_size, en_seq_len, en_hidden_size * 2]
:return:
"""
# 依次通过embedding、gru 和fc 最终返回log_softmax
# decoder_input_embeded [batch_size, 1] -> [batch_size, 1, embedding_dim]
decoder_input_embeded = self.embedding(decoder_input)
decoder_output, decoder_hidden = self.gru(decoder_input_embeded, decoder_hidden)
"""通过decoder_hidden和encoder_outputs进行attention计算"""
# 1.通过attention 计算attention weight
# decoder_hidden [1, batch_size, en_hidden_size * 2]
# encoder_outputs [batch_size, en_seq_len, en_hidden_size * 2]
# attn_weight [batch_size, en_seq_len]
attn_weight = self.attn(decoder_hidden, encoder_outputs)
# 2.attn_weight与encoder_outputs 计算得到上下文向量
# encoder_outputs[batch_size, en_seq_len, en_hidden_size * 2]
# attn_weight [batch_size, en_seq_len]
# 二者进行矩阵乘法,需要对attn_weight进行维度扩展->[batch_size, 1, en_seq_len]
# context_vector [batch_size, 1, en_hidden_size * 2]
context_vector = torch.bmm(attn_weight.unsqueeze(1), encoder_outputs)
# 3.context_vector 与decoder每个时间步输出decoder_output进行拼接和线性变换,得到每个时间步的注意力结果输出
# decoder_output[batch_size, 1, en_hidden_size * 2]
# context_vector[batch_size, 1, en_hidden_size * 2]
# 拼接后形状为 [batch_size, 1, en_hidden_size * 4]
# 由于要进行全连接操作,需要对拼接后形状进行降维unsqueeze(1)
# ->[batch_size, en_hidden_size * 4]
# 且decoder每个时间步输出结果经过self.fc的维度为[batch_size, vocab_size]
# 因此,self.attn的fc 输入输出维度为(en_hidden_size * 4, en_hidden_size * 2)
# self.fc输入输出维度为(en_hidden_size * 2, vocab_size)
# 注意:这里用torch.tanh,使用F.tanh会 'nn.functional.tanh is deprecated. Use torch.tanh instead'
attn_result = torch.tanh(self.fc_attn(torch.cat((decoder_output, context_vector),dim = -1).squeeze(1)))
# attn_result [batch_size, en_hidden_size * 2]
# 经过self.fc后改变维度
# decoder_output_t [batch_size, en_hidden_size * 2]->[batch_size, vocab_size]
decoder_output_t = F.log_softmax(self.fc(attn_result), dim = -1)
# decoder_hiddem [1, batch_size, en_hidden_size * 2]
return decoder_output_t, decoder_hidden
def evaluate(self, encoder_hidden, encoder_outputs):
"""
评估逻辑
:param encoder_hidden: [1, batch_size, en_hidden_size * 2]
:param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
:return:
"""
batch_size = encoder_hidden.size(1)
# 初始化decoder第一个时间步的输入和hidden和decoder输出
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)
decoder_outputs = torch.zeros((batch_size, config.chatbot_target_max_len, len(config.target_ws))).to(config.device)
decoder_hidden = encoder_hidden
# 初始化用于存储的预测序列
predict_result = []
for t in range(config.chatbot_target_max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
decoder_outputs[:, t, :] = decoder_output_t
index = torch.argmax(decoder_output_t, dim = -1)
decoder_input = index.unsqueeze(-1)
predict_result.append(index.cpu().detach().numpy())
predict_result = np.array(predict_result).transpose()
return decoder_outputs, predict_result
Modified seq2seq: Seq2Seq_attn.py
from encoder import Encoder
from decoder_attn import Decoder_Attn
import torch.nn as nn
class Seq2Seq_Attn(nn.Module):
def __init__(self):
super(Seq2Seq_Attn, self).__init__()
self.encoder = Encoder()
self.decoder = Decoder_Attn()
def forward(self, input, target, input_length, target_length):
encoder_output, encoder_hidden = self.encoder(input, input_length)
decoder_output, decoder_hidden = self.decoder(encoder_hidden, target, encoder_output)
return decoder_output
def evaluation_attn(self, input, input_length):
encoder_output, encoder_hidden = self.encoder(input, input_length)
decoder_output, predict_result = self.decoder.evaluate(encoder_hidden, encoder_output)
return decoder_output, predict_result
train.py
from Seq2Seq_attn import Seq2Seq_Attn
import torch
import torch.optim as optim
import os
from tqdm import tqdm
from dataset import get_dataloader
import config
import torch.nn.functional as F
model = Seq2Seq_Attn().to(config.device)
if os.path.exists('./model/chatbot_attn_model.pkl'):
model.load_state_dict(torch.load('./model/chatbot_attn_model.pkl'))
optimizer = optim.Adam(model.parameters())
loss_list = []
def train_attn(epoch):
model.train()
train_dataloader = get_dataloader(train = True)
bar = tqdm(train_dataloader, desc = 'attn_training...', total = len(train_dataloader))
for idx, (input, target, input_length, target_length) in enumerate(bar):
input = input.to(config.device)
target = target.to(config.device)
input_length = input_length.to(config.device)
target_length = target_length.to(config.device)
optimizer.zero_grad()
outputs = model(input, target, input_length, target_length)
# outputs [batch_size, de_seq_len, vocab_size]
# target [batch_size, de_seq_len]
loss = F.nll_loss(outputs.view(-1, len(config.target_ws)), target.view(-1), ignore_index = config.target_ws.PAD)
loss.backward()
optimizer.step()
loss_list.append(loss.item())
        bar.set_description('epoch:{}, idx:{}/{}, loss:{:.6f}'.format(epoch + 1, idx, len(train_dataloader), loss.item()))
if idx % 100 == 0:
torch.save(model.state_dict(), './model/chatbot_attn_model.pkl')
if __name__ == '__main__':
for i in range(10):
train_attn(i)
eval.py
from Seq2Seq_attn import Seq2Seq_Attn
import torch
import config
import numpy as np
from dataset import get_dataloader
import torch.nn.functional as F
from tqdm import tqdm
model = Seq2Seq_Attn().to(config.device)
model.load_state_dict(torch.load('./model/chatbot_attn_model.pkl'))
loss_list = []
def eval():
model.eval()
test_dataloader = get_dataloader(train = False)
bar = tqdm(test_dataloader, desc = 'attn_test...', total = len(test_dataloader))
for idx, (input, target, input_length, target_length) in enumerate(bar):
input = input.to(config.device)
target = target.to(config.device)
input_length = input_length.to(config.device)
target_length = target_length.to(config.device)
with torch.no_grad():
output, predict_result = model.evaluation_attn(input, input_length)
loss = F.nll_loss(output.view(-1, len(config.target_ws)), target.view(-1), ignore_index = config.target_ws.PAD)
loss_list.append(loss.item())
bar.set_description('idx:{}/{}, loss:{}'.format(idx, len(test_dataloader), np.mean(loss_list)))
if __name__ == '__main__':
eval()
interface.py
from cut_sentence import cut
import torch
import config
from Seq2Seq_attn import Seq2Seq_Attn
import os
# simulate a chat: answer whatever the user types
def interface():
    # load the trained model
    model = Seq2Seq_Attn().to(config.device)
    assert os.path.exists('./model/chatbot_attn_model.pkl'), 'Train the model first by running train.py!'
    model.load_state_dict(torch.load('./model/chatbot_attn_model.pkl'))
    model.eval()
    while True:
        # tokenize the raw input string
        input_string = input('me>>:')
        if input_string == 'q':
            print('See you next time!')
            break
        input_cuted = cut(input_string, by_word = True)
        # convert to an index sequence and wrap as tensors
        input_tensor = torch.LongTensor([config.input_ws.transfrom(input_cuted, max_len = config.chatbot_input_max_len)]).to(config.device)
        input_length_tensor = torch.LongTensor([len(input_cuted)]).to(config.device)
        # run inference
        outputs, predict = model.evaluation_attn(input_tensor, input_length_tensor)
        # convert the predicted indices back to text
result = config.target_ws.inverse_transform(predict[0])
print('chatbot>>:', result)
if __name__ == '__main__':
interface()
Modified decoder: decoder_beam.py
import torch
import torch.nn as nn
import config
from attention import Attention
import random
import torch.nn.functional as F
import numpy as np
import heapq
class Decoder_Attn_Beam(nn.Module):
def __init__(self):
super(Decoder_Attn_Beam, self).__init__()
self.vocab_size = len(config.target_ws)
        # the encoder's Bi-GRU hidden states are concatenated; keeping
        # hidden_size = en_hidden_size makes the attention shapes line up
self.hidden_size = config.chatbot_encoder_hidden_size
self.embedding = nn.Embedding(
num_embeddings = self.vocab_size,
embedding_dim = config.chatbot_decoder_embedding_dim
)
self.gru = nn.GRU(
input_size = config.chatbot_decoder_embedding_dim,
hidden_size = self.hidden_size * 2,
num_layers = config.chatbot_decoder_numlayers,
batch_first = True,
bidirectional = False
)
        # maps each decoder time-step output in forward_step to vocabulary scores
self.fc = nn.Linear(self.hidden_size * 2, self.vocab_size)
        # attention module
self.attn = Attention(method = 'general')
        # fc_attn maps the concatenated [decoder_output; context_vector] back to hidden_size * 2
self.fc_attn = nn.Linear(self.hidden_size * 4, self.hidden_size * 2)
def forward(self, encoder_hidden, target, encoder_outputs):
"""
:param encoder_hidden: [1, batch_size, en_hidden_size * 2]
:param target: [batch_size, en_seq_len]
添加attention的decoder中, 对encoder_outputs进行利用与decoder_hidden计算得到注意力表示
encoder_outputs为新增参数
:param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
:return:
"""
batch_size = encoder_hidden.size(1)
# 初始化一个[batch_size, 1]的SOS作为decoder第一个时间步的输入decoder_input
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)
# 初始化一个[batch_size, de_seq_len, vocab_size]的张量拼接每个time step结果
decoder_outputs = torch.zeros([batch_size, config.chatbot_target_max_len, self.vocab_size]).to(config.device)
# encoder_hidden 作为decoder的第一个时间步的hidden
decoder_hidden = encoder_hidden
# 按照teacher_forcing的更新策略进行每个时间步的更新
teacher_forcing = random.random() > 0.5
if teacher_forcing:
            # iterate over the time steps
            for t in range(config.chatbot_target_max_len):
                # decoder_output_t [batch_size, vocab_size]
                # decoder_hidden [1, batch_size, en_hidden_size * 2]
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                # store the output of time step t
                decoder_outputs[:, t, :] = decoder_output_t
                # with teacher forcing, the ground truth becomes the next input
                # target[:, t] [batch_size] -> decoder_input [batch_size, 1]
                decoder_input = target[:, t].unsqueeze(1)
        else:
            for t in range(config.chatbot_target_max_len):
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                decoder_outputs[:, t, :] = decoder_output_t
                # without teacher forcing, the prediction becomes the next input
                # index [batch_size] -> decoder_input [batch_size, 1]
                index = torch.argmax(decoder_output_t, dim = -1)
                decoder_input = index.unsqueeze(-1)
return decoder_outputs, decoder_hidden
def forward_step(self, decoder_input, decoder_hidden, encoder_outputs):
"""
每个时间步的处理
:param decoder_input: [batch_size, 1]
:param decoder_hidden: [1, batch_size, en_hidden_size * 2]
:param encoder_outputs [batch_size, en_seq_len, en_hidden_size * 2]
:return:
"""
# 依次通过embedding、gru 和fc 最终返回log_softmax
# decoder_input_embeded [batch_size, 1] -> [batch_size, 1, embedding_dim]
decoder_input_embeded = self.embedding(decoder_input)
decoder_output, decoder_hidden = self.gru(decoder_input_embeded, decoder_hidden)
"""通过decoder_hidden和encoder_outputs进行attention计算"""
# 1.通过attention 计算attention weight
# decoder_hidden [1, batch_size, en_hidden_size * 2]
# encoder_outputs [batch_size, en_seq_len, en_hidden_size * 2]
# attn_weight [batch_size, en_seq_len]
attn_weight = self.attn(decoder_hidden, encoder_outputs)
# 2.attn_weight与encoder_outputs 计算得到上下文向量
# encoder_outputs[batch_size, en_seq_len, en_hidden_size * 2]
# attn_weight [batch_size, en_seq_len]
# 二者进行矩阵乘法,需要对attn_weight进行维度扩展->[batch_size, 1, en_seq_len]
# context_vector [batch_size, 1, en_hidden_size * 2]
context_vector = torch.bmm(attn_weight.unsqueeze(1), encoder_outputs)
# 3.context_vector 与decoder每个时间步输出decoder_output进行拼接和线性变换,得到每个时间步的注意力结果输出
# decoder_output[batch_size, 1, en_hidden_size * 2]
# context_vector[batch_size, 1, en_hidden_size * 2]
# 拼接后形状为 [batch_size, 1, en_hidden_size * 4]
# 由于要进行全连接操作,需要对拼接后形状进行降维unsqueeze(1)
# ->[batch_size, en_hidden_size * 4]
# 且decoder每个时间步输出结果经过self.fc的维度为[batch_size, vocab_size]
# 因此,self.attn的fc 输入输出维度为(en_hidden_size * 4, en_hidden_size * 2)
# self.fc输入输出维度为(en_hidden_size * 2, vocab_size)
# 注意:这里用torch.tanh,使用F.tanh会 'nn.functional.tanh is deprecated. Use torch.tanh instead'
attn_result = torch.tanh(self.fc_attn(torch.cat((decoder_output, context_vector),dim = -1).squeeze(1)))
# attn_result [batch_size, en_hidden_size * 2]
# 经过self.fc后改变维度
# decoder_output_t [batch_size, en_hidden_size * 2]->[batch_size, vocab_size]
decoder_output_t = F.log_softmax(self.fc(attn_result), dim = -1)
# decoder_hiddem [1, batch_size, en_hidden_size * 2]
return decoder_output_t, decoder_hidden
def evaluate(self, encoder_hidden, encoder_outputs):
"""
评估逻辑
:param encoder_hidden: [1, batch_size, en_hidden_size * 2]
:param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
:return:
"""
batch_size = encoder_hidden.size(1)
# 初始化decoder第一个时间步的输入和hidden和decoder输出
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)
decoder_outputs = torch.zeros((batch_size, config.chatbot_target_max_len, len(config.target_ws))).to(config.device)
decoder_hidden = encoder_hidden
# 初始化用于存储的预测序列
predict_result = []
for t in range(config.chatbot_target_max_len):
decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
decoder_outputs[:, t, :] = decoder_output_t
index = torch.argmax(decoder_output_t, dim = -1)
decoder_input = index.unsqueeze(-1)
predict_result.append(index.cpu().detach().numpy())
predict_result = np.array(predict_result).transpose()
return decoder_outputs, predict_result
def evaluate_beam(self, encoder_hidden, encoder_outputs):
"""
使用beam search的评估
:param encoder_hidden: [1, batch_size, en_hidden_size * 2]
:param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
:return:
"""
# 注意:在beam search的过程中,batch_size 只能为1
batch_size = encoder_hidden.size(0)
assert batch_size == 1, 'batch_size 不为1'
# 初始化decoder的输入和输出及隐藏状态
decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)
decoder_hidden = encoder_hidden
# 实例化首次beam
prev_beam = Beam()
# 第一次需要的输入数据,保存在堆中
prev_beam.add(1, False, [decoder_input], decoder_input, decoder_hidden)
# 循环比较堆中前一次和后一次的数据
while True:
# 实例化当前的beam
cur_beam = Beam()
# 取出堆中的数据,判断是否遇到EOS,若是,则添加进堆中,若不是则进行forward_step
for _prob, _complete, _seq_list, _decoder_input, _decoder_hidden in prev_beam:
if _complete:
cur_beam.add(_prob, _complete, _seq_list, _decoder_input, _decoder_hidden)
else:
# decoder_output_t [1, vocab_size]
# decoder_hidden [1, 1, en_hidden_size * 2]
decoder_output_t, decoder_hidden = self.forward_step(_decoder_input, _decoder_hidden, encoder_outputs)
value, index = torch.topk(decoder_output_t, k = config.beam_width)
for val, idx in zip(value[0], index[0]):
cur_prob = _prob * val.item()
decoder_input = torch.LongTensor([[idx.item()]]).to(config.device)
cur_seq_list = _seq_list + [decoder_input]
if idx == config.target_ws.EOS:
cur_complete = True
else:
cur_complete = False
cur_beam.add(cur_prob, cur_complete, cur_seq_list, decoder_input, decoder_hidden)
# 获取新的堆中的优先级最高(概率最大)的数据,判断数据是否是EOS结尾或者是否达到最大长度,如果是,停止迭代
best_prob, best_complete, best_seq_list, _, _ = max(cur_beam)
if best_complete or len(best_seq_list) - 1 == config.chatbot_target_max_len:
# 对结果进行基础的处理,共后续转化为文字使用
best_seq_list = [i.item() for i in best_seq_list]
if best_seq_list[0] == config.target_ws.SOS:
best_seq_list = best_seq_list[1:]
if best_seq_list[-1] == config.target_ws.EOS:
best_seq_list = best_seq_list[:-1]
return best_seq_list
else:
# 则重新遍历新的堆中的数据
prev_beam = cur_beam
class Beam(object):
    # beam search candidates kept in a bounded min-heap
    def __init__(self):
        self.heapq = list()                  # underlying heap storage
        self.beam_width = config.beam_width  # keep at most beam_width candidates
        self._counter = 0                    # tie-breaker so heapq never compares tensors
    def add(self, prob, complete, seq_list, decoder_input, decoder_hidden):
        """
        Push a candidate; if the heap exceeds beam_width, drop the worst one
        :param prob: accumulated log-probability
        :param complete: whether the sequence ends with EOS
        :param seq_list: list of all tokens so far
        :param decoder_input: next decoding input, produced by the previous step
        :param decoder_hidden: next decoding hidden state, produced by the previous step
        :return:
        """
        heapq.heappush(self.heapq, (prob, self._counter, complete, seq_list, decoder_input, decoder_hidden))
        self._counter += 1
        # keep only the beam_width best candidates (the smallest score is popped)
        if len(self.heapq) > self.beam_width:
            heapq.heappop(self.heapq)
    def __iter__(self):
        for prob, _, complete, seq_list, decoder_input, decoder_hidden in self.heapq:
            yield prob, complete, seq_list, decoder_input, decoder_hidden
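A tiny standalone check of the bounded heap (a sketch with dummy payloads; it assumes the vocab pickles exist so importing decoder_beam's config dependency succeeds, and beam_width = 2 as configured):

from decoder_beam import Beam

beam = Beam()
for log_prob in [-5.0, -1.0, -3.0, -0.5]:
    beam.add(log_prob, False, [], None, None)
# only the two highest-scoring candidates survive
print(sorted(prob for prob, *_ in beam))     # [-1.0, -0.5]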
Seq2Seq_beam.py
from encoder import Encoder
from decoder_beam import Decoder_Attn_Beam
import torch.nn as nn
class Seq2Seq_Attn_Beam(nn.Module):
def __init__(self):
super(Seq2Seq_Attn_Beam, self).__init__()
self.encoder = Encoder()
self.decoder = Decoder_Attn_Beam()
def forward(self, input, target, input_length, target_length):
encoder_output, encoder_hidden = self.encoder(input, input_length)
decoder_output, decoder_hidden = self.decoder(encoder_hidden, target, encoder_output)
return decoder_output
def evaluation_attn(self, input, input_length):
encoder_output, encoder_hidden = self.encoder(input, input_length)
decoder_output, predict_result = self.decoder.evaluate(encoder_hidden, encoder_output)
return decoder_output, predict_result
def evaluation_beam(self, input, input_length):
encoder_output, encoder_hidden = self.encoder(input, input_length)
best_seq = self.decoder.evaluate_beam(encoder_hidden, encoder_output)
return best_seq
interface.py
from cut_sentence import cut
import torch
import config
from Seq2Seq_beam import Seq2Seq_Attn_Beam
import os
# simulate a chat: answer whatever the user types
def interface():
    # load the trained model
    model = Seq2Seq_Attn_Beam().to(config.device)
    assert os.path.exists('./model/chatbot_attn_model.pkl'), 'Train the model first by running train.py!'
    model.load_state_dict(torch.load('./model/chatbot_attn_model.pkl'))
    model.eval()
    while True:
        # tokenize the raw input string
        input_string = input('me>>:')
        if input_string == 'q':
            print('See you next time!')
            break
        input_cuted = cut(input_string, by_word = True)
        # convert to an index sequence and wrap as tensors
        input_tensor = torch.LongTensor([config.input_ws.transfrom(input_cuted, max_len = config.chatbot_input_max_len)]).to(config.device)
        input_length_tensor = torch.LongTensor([len(input_cuted)]).to(config.device)
        # run inference
        predict = model.evaluation_beam(input_tensor, input_length_tensor)
        # convert the predicted indices back to text
        # (beam search returns a plain list of indices)
result = config.target_ws.inverse_transform(predict)
print('chatbot>>:', result)
if __name__ == '__main__':
interface()