Seq2Seq Chatbot

Implementing the basic logic

config.py

import pickle 
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


"""word2sequence"""

chatbot_train_batch_size = 200
chatbot_test_batch_size = 300

# vocabularies produced by gen_ws.py (run that script once before importing config)
input_ws = pickle.load(open('./model/ws_input.pkl', 'rb'))
target_ws = pickle.load(open('./model/ws_target.pkl', 'rb'))

chatbot_input_max_len = 20
chatbot_target_max_len = 30


"""Encoder"""
chatbot_encoder_embedding_dim = 300
chatbot_encoder_hidden_size = 128
chatbot_encoder_numlayers = 2
chatbot_encoder_bidirectional = True
# dropout in nn.GRU only takes effect when num_layers >= 2
chatbot_encoder_dropout = 0.3

"""Decoder"""
chatbot_decoder_embedding_dim = 300
chatbot_decoder_numlayers = 1
teacher_forcing = 0.5

"""beam search"""
beam_width = 2
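
Note that config.py unpickles both vocabularies at import time, so gen_ws.py (below) has to be run once before any other module imports config.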

Dataset preparation

The dataset combines the Xiaohuangji (小黄鸡) corpus with Weibo and other corpora. After basic cleaning, every sentence is split character by character, producing input.txt and target.txt. Samples for reference:
input.txt

呵 呵
不 是
怎 么 了
开 心 点 哈 , 一 切 都 会 好 起 来
我 还 喜 欢 她 , 怎 么 办
短 信
你 知 道 谁 么
许 兵 是 谁
这 么 假
许 兵 是 傻 逼
许 兵 是 谁
许 兵 是 谁
许 兵 是 谁
许 兵 到 底 是 谁
尼 玛 , 许 兵 到 底 是 谁
小 黄 鸭 , 你 有 女 朋 友 么
那 你 有 男 朋 友 么
那 你 在 哪

target.txt

是 王 若 猫 的 。
那 是 什 么 。
我 很 难 过 , 安 慰 我 。
嗯 , 会 的
我 帮 你 告 诉 她 , 发 短 信 还 是 打 电 话 。
嗯 嗯 , 我 也 相 信
肯 定 不 是 我 , 是 阮 德 培
吴 院 四 班 小 帅 哥
三 鹿 奶 粉 也 假 , 不 一 样 的 卖 啊
被 你 发 现 了 。
是 我 善 良 可 爱 的 主 人 的 老 公 啊
是 穆 森 的 老 婆 啊
奇 葩
不 要 凶 我 , 应 该 大 概 也 许 是 叶 祺 吧 。
我 , 我 不 会 告 诉 你 我 是 小 澜 的 。
老 娘 是 女 的 。
没 有 呢 , 我 只 要 主 人 一 个 人 疼 爱 我 嘛 。
我 无 聊
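
The character-level files above can be produced with the cut() helper defined in the next section. A minimal sketch, assuming the raw corpus has already been merged into one file holding a question and its answer per line, separated by a tab (the file name and format here are hypothetical):

# gen_corpus.py -- sketch: build input.txt / target.txt from a raw Q-A file
# assumes ./corpus/raw_pairs.txt holds one "question<TAB>answer" pair per line
# (hypothetical name and format); cut() comes from cut_sentence.py below
from cut_sentence import cut

with open('./corpus/raw_pairs.txt', encoding = 'utf-8') as f_raw, \
     open('./corpus/input.txt', 'w', encoding = 'utf-8') as f_in, \
     open('./corpus/target.txt', 'w', encoding = 'utf-8') as f_out:
    for line in f_raw:
        line = line.strip()
        if not line or '\t' not in line:
            continue  # skip malformed lines
        question, answer = line.split('\t', maxsplit = 1)
        # split character by character, join with spaces to match the files above
        f_in.write(' '.join(cut(question, by_word = True)) + '\n')
        f_out.write(' '.join(cut(answer, by_word = True)) + '\n')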

Tokenization: cut_sentence.py

import string
import jieba
import jieba.posseg as psg
import logging


# paths to the stop-word list and the custom dictionary
stopwords_path = './corpus/stopwords.txt'
keywords_path = './corpus/keywords.txt'
# lowercase English letters
letters = string.ascii_lowercase
# silence jieba's logging
jieba.setLogLevel(logging.INFO)
# load the stop words, one per line (readlines, not readline)
stop_words = [i.strip() for i in open(stopwords_path, encoding = 'utf-8').readlines()]


def cut(sentence, by_word = False, use_stopwords = False, use_seg = False):
    """
    分词方法
    :param sentence: 待分词的句子
    :param by_word: 按字切分
    :param use_stopwords: 使用停用词
    :param use_seg: 返回词性
    :return:
    """
    if by_word:
        return cut_sentence_by_word(sentence)
    else:
        return cut_sentence(sentence, use_stopwords, use_seg)


def cut_sentence(sentence, use_stopwords, use_seg):
    if use_seg:
        # psg.lcut returns a list of pair objects carrying .word and .flag
        result = psg.lcut(sentence)
        if use_stopwords:
            result = [i for i in result if i.word not in stop_words]
    else:
        result = jieba.lcut(sentence)
        if use_stopwords:
            result = [i for i in result if i not in stop_words]
    return result


def cut_sentence_by_word(sentence):
    """
    按字进行切分
    :param sentence: 待分词的语句
    :return:
    """
    temp = '' # 用于拼接英文字符
    result = [] # 保存结果
    # 按字遍历
    for word in sentence:
        # 判断是否是英文字母
        if word in letters:
            temp += word
        # 若遇到非英文字符有两种情况
        # 1.temp = ''意味当前是个汉字,将word直接存入result中
        # 2.temp != '' 意味着拼接完了一个单词,遇到了当前汉字,需要将temp存入reslut,并置空
        else:
            if len(temp) > 0:
                result.append(temp)
                temp = ''
            result.append(word)
        # 当遍历完所有字后,最后一个字母可能存储在temp中
    if len(temp) > 0:
        result.append(temp)
    return result


if __name__ == '__main__':
    sentence = '今天天气好热a'
    res1 = cut(sentence, by_word = True)
    res2 = cut(sentence, by_word = False, use_seg = True)
    res3 = cut(sentence, by_word = False, use_seg = False)
    print(res1)
    print(res2)
    print(res3)
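
For this sample sentence, res1 is deterministic: ['今', '天', '天', '气', '好', '热', 'a'], with the trailing letter kept as its own token. res2 and res3 go through jieba, so the exact segmentation depends on the dictionary version; with the default dictionary, something like [pair('今天', 't'), pair('天气', 'n'), ...] and ['今天', '天气', '好', '热', 'a'] is typical.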

Building the vocabulary: converting between sequences and text

word2sequence.py



class Word2Sequence(object):
    # special tokens and their reserved indices
    PAD_TAG = '<PAD>'
    UNK_TAG = '<UNK>'
    SOS_TAG = '<SOS>'
    EOS_TAG = '<EOS>'

    PAD = 0
    UNK = 1
    SOS = 2
    EOS = 3


    def __init__(self):
        # vocabulary: token -> index
        self.dict = {
            self.PAD_TAG: self.PAD,
            self.UNK_TAG: self.UNK,
            self.SOS_TAG: self.SOS,
            self.EOS_TAG: self.EOS
        }
        # word-frequency counter
        self.count = {}


    def fit(self, sentence):
        """
        统计每句话中的词频
        :param sentence: 经过分词后的句子
        :return:
        """
        for word in sentence:
            self.count[word] = self.count.get(word, 0) + 1


    def build_vocab(self, min = 5, max = None, max_features = None):
        """
        Build the vocabulary.
        :param min: minimum word frequency
        :param max: maximum word frequency
        :param max_features: maximum vocabulary size
        :return:
        """
        # note that both frequency bounds are inclusive
        if min is not None:
            self.count = {k: v for k, v in self.count.items() if v >= min}
        if max is not None:
            self.count = {k: v for k, v in self.count.items() if v <= max}
        if max_features is not None:
            # [(k, v), (k, v), ...] -> {k: v, k: v}
            self.count = dict(sorted(self.count.items(), key = lambda x: x[1], reverse = True)[:max_features])
        # assign indices in insertion order
        for word in self.count:
            self.dict[word] = len(self.dict)
        # reverse mapping: index -> token
        self.inverse_dict = dict(zip(self.dict.values(), self.dict.keys()))


    def transform(self, sentence, max_len = None, add_eos = False):
        """
        Convert text to an index sequence.
        :param sentence: a tokenized sentence
        :param max_len: maximum sequence length
        :param add_eos: whether to append the EOS marker
        :return:
        """
        # reserve one position for EOS
        if max_len and add_eos:
            max_len = max_len - 1
        # truncate sentences that are too long
        if max_len <= len(sentence):
            sentence = sentence[:max_len]
        # pad sentences that are too short
        if max_len > len(sentence):
            sentence += [self.PAD_TAG] * (max_len - len(sentence))
        if add_eos:
            # insert EOS in front of the first PAD marker
            if self.PAD_TAG in sentence:
                index = sentence.index(self.PAD_TAG)
                sentence.insert(index, self.EOS_TAG)
            # no padding: simply append EOS
            else:
                sentence += [self.EOS_TAG]
        return [self.dict.get(i, self.UNK) for i in sentence]


    def inverse_transform(self, indices):
        """
        Convert an index sequence back to text.
        :param indices: the index sequence
        :return:
        """
        result = []
        for i in indices:
            # unknown indices fall back to the UNK marker
            temp = self.inverse_dict.get(i, self.UNK_TAG)
            # stop as soon as the EOS index is reached
            # (compare against self.EOS, the index, not EOS_TAG, the string)
            if i != self.EOS:
                result.append(temp)
            else:
                break
        # join the tokens back into one string
        return ''.join(result)


    def __len__(self):
        return len(self.dict)

if __name__ == '__main__':
    sentences = [["今天", "天气", "很", "好"],
                 ["今天", "去", "吃", "什么"]]
    ws = Word2Sequence()
    for sentence in sentences:
        ws.fit(sentence)
    ws.build_vocab(min = 1)
    print('vocab_dict', ws.dict)
    ret = ws.transform(["好", "好", "好", "好", "好", "好", "好", "热", "呀"], max_len = 13, add_eos = True)
    print('transform', ret)
    ret = ws.inverse_transform(ret)
    print('inverse', ret)
    print(ws.PAD_TAG)
    print(ws.PAD)
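
Walking through the demo (deterministic on CPython 3.7+, where dicts preserve insertion order): build_vocab(min = 1) keeps every word, so the vocabulary is {'<PAD>': 0, '<UNK>': 1, '<SOS>': 2, '<EOS>': 3, '今天': 4, '天气': 5, '很': 6, '好': 7, '去': 8, '吃': 9, '什么': 10}. transform first reserves one slot for EOS (max_len becomes 12), pads the 9-token sentence with three PAD markers, inserts EOS in front of the first one, and maps 热 and 呀 (absent from the vocabulary) to UNK, yielding [7, 7, 7, 7, 7, 7, 7, 1, 1, 3, 0, 0, 0]. inverse_transform then stops at the EOS index and prints 好好好好好好好<UNK><UNK>.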

Splitting the dataset

gen_ws.py

import pickle
from tqdm import tqdm
import random
from word2sequence import Word2Sequence
# note: do NOT import config here -- config.py unpickles ws_input.pkl and
# ws_target.pkl at import time, and they do not exist until this script has run

def chatbot_data_split():
    # split the corpus into a training and a test set
    input_lines = open('./corpus/input.txt', encoding = 'utf-8').readlines()
    target_lines = open('./corpus/target.txt', encoding = 'utf-8').readlines()
    # training set ('w' rather than 'a', so reruns do not duplicate data)
    f_train_input = open('./corpus/train_input.txt', 'w', encoding = 'utf-8')
    f_train_target = open('./corpus/train_target.txt', 'w', encoding = 'utf-8')
    # test set
    f_test_input = open('./corpus/test_input.txt', 'w', encoding = 'utf-8')
    f_test_target = open('./corpus/test_target.txt', 'w', encoding = 'utf-8')
    # draw one (input, target) pair at a time and route it to the
    # training or test set with an 80:20 split
    for input_, target in tqdm(zip(input_lines, target_lines), desc = 'splitting'):
        if random.random() > 0.2:
            f_train_input.write(input_)
            f_train_target.write(target)
        else:
            f_test_input.write(input_)
            f_test_target.write(target)

    f_train_input.close()
    f_train_target.close()
    f_test_input.close()
    f_test_target.close()


def gen_ws(train_path, test_path, save_path):
    """
    Build and save a vocabulary.
    :param train_path: training data
    :param test_path: test data
    :param save_path: where to pickle the vocabulary
    :return:
    """
    ws = Word2Sequence()
    for line in tqdm(open(train_path, encoding = 'utf-8').readlines(), desc = 'build_vocab1..'):
        ws.fit(line.strip().split())
    for line in tqdm(open(test_path, encoding = 'utf-8').readlines(), desc = 'build_vocab2..'):
        ws.fit(line.strip().split())

    ws.build_vocab(min = 5, max = None, max_features = 5000)
    print(len(ws))
    pickle.dump(ws, open(save_path, 'wb'))


if __name__ == '__main__':
    chatbot_data_split()
    train_input_path = './corpus/train_input.txt'
    test_input_path = './corpus/test_input.txt'
    train_target_path = './corpus/train_target.txt'
    test_target_path = './corpus/test_target.txt'
    input_ws_path = './model/ws_input.pkl'
    target_ws_path = './model/ws_target.pkl'
    gen_ws(train_input_path, test_input_path, input_ws_path)
    gen_ws(train_target_path, test_target_path, target_ws_path)
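
After one run, ./model/ws_input.pkl and ./model/ws_target.pkl exist, and the two printed lengths are the vocabulary sizes: at most 5004 each, i.e. max_features = 5000 words plus the four special markers.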

构建可迭代数据集

dataset.py

import torch
import config
from torch.utils.data import Dataset, DataLoader


class ChatbotDataset(Dataset):
    def __init__(self, train = True):
        """
        :param train: whether to load the training split or the test split
        """
        input_path = './corpus/train_input.txt' if train else './corpus/test_input.txt'
        target_path = './corpus/train_target.txt' if train else './corpus/test_target.txt'
        self.input_data = open(input_path, encoding = 'utf-8').readlines()
        self.target_data = open(target_path, encoding = 'utf-8').readlines()
        # a chat model needs exactly one target line per input line
        assert len(self.input_data) == len(self.target_data), 'input and target lengths differ!'


    def __getitem__(self, idx):
        input = self.input_data[idx].strip().split()
        target = self.target_data[idx].strip().split()
        # real lengths, clipped to the configured maxima
        input_length = len(input) if len(input) < config.chatbot_input_max_len else config.chatbot_input_max_len
        target_length = len(target) if len(target) < config.chatbot_target_max_len else config.chatbot_target_max_len

        # convert the tokens to index sequences
        input = config.input_ws.transform(input, max_len = config.chatbot_input_max_len)
        target = config.target_ws.transform(target, max_len = config.chatbot_target_max_len)
        return input, target, input_length, target_length

    
    def __len__(self):
        return len(self.input_data)


def get_dataloader(train = True):
    # dataloader for the training or the test split
    batch_size = config.chatbot_train_batch_size if train else config.chatbot_test_batch_size
    return DataLoader(ChatbotDataset(train), batch_size = batch_size, shuffle = True, collate_fn = collate_fn)


def collate_fn(batch):
    # sort each batch by input length, longest first, as required by
    # nn.utils.rnn.pack_padded_sequence in the encoder
    batch = sorted(batch, key = lambda x: x[2], reverse = True)
    input, target, input_length, target_length = zip(*batch)
    # wrap everything in tensors
    input_tensor = torch.LongTensor(input)
    target_tensor = torch.LongTensor(target)
    input_length = torch.LongTensor(input_length)
    target_length = torch.LongTensor(target_length)
    return input_tensor, target_tensor, input_length, target_length


if __name__ == '__main__':
    data_loader = get_dataloader(train = False)
    for idx, (input, target, input_length, target_length) in enumerate(data_loader):
        print(input)
        print(input.size())  # [batch_size, seq_len]
        print(target)
        print(target.size())  # [batch_size, seq_len]
        print(input_length)
        print(input_length.size())  # [batch_size]
        print(target_length)
        print(target_length.size())  # [batch_size]
        break
    print(config.target_ws.dict)
    print(len(config.input_ws))
    print(len(config.target_ws))

Implementing Seq2Seq

encoder.py

import torch.nn as nn
import config
import torch


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.vocab_size = len(config.input_ws)
        self.embedding = nn.Embedding(
            num_embeddings = self.vocab_size,
            embedding_dim = config.chatbot_encoder_embedding_dim
        )
        self.gru = nn.GRU(
            input_size = config.chatbot_encoder_embedding_dim,
            hidden_size = config.chatbot_encoder_hidden_size,
            num_layers = config.chatbot_encoder_numlayers,
            bidirectional = config.chatbot_encoder_bidirectional,
            dropout = config.chatbot_encoder_dropout,
            batch_first = True
        )


    def forward(self, input, input_length):
        # embedding: input [batch_size, seq_len] -> [batch_size, seq_len, embedding_dim]
        input_embeded = self.embedding(input)
        # pack the padded batch (lengths must be a 1D CPU int64 tensor)
        input_packed = nn.utils.rnn.pack_padded_sequence(input = input_embeded, lengths = input_length.cpu(), batch_first = True)
        # GRU: output -> [batch_size, seq_len, hidden_size * 2]
        # hidden -> [num_layers * 2, batch_size, hidden_size]
        output, hidden = self.gru(input_packed)
        # unpack: encoder_output -> [batch_size, seq_len, hidden_size * 2]
        encoder_output, output_length = nn.utils.rnn.pad_packed_sequence(sequence = output, batch_first = True, padding_value = config.input_ws.PAD)
        # the GRU is bidirectional, so concatenate the two final hidden states:
        # hidden [num_layers * 2, batch_size, hidden_size]
        # hidden[-1] == hidden[-1, :, :]
        # torch.cat((hidden[-1, :, :], hidden[-2, :, :]), dim = -1)
        # -> encoder_hidden [batch_size, hidden_size * 2]
        # the decoder expects a 3D hidden state, so unsqueeze to [1, batch_size, hidden_size * 2]
        encoder_hidden = torch.cat((hidden[-1, :, :], hidden[-2, :, :]), dim = -1).unsqueeze(0)
        return encoder_output, encoder_hidden
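
A quick sanity check of the encoder's output shapes (a sketch; it assumes the split corpus files and the pickled vocabularies from the previous steps exist):

# shape check for the encoder (sketch)
import config
from dataset import get_dataloader
from encoder import Encoder

encoder = Encoder().to(config.device)
input, target, input_length, target_length = next(iter(get_dataloader(train = True)))
output, hidden = encoder(input.to(config.device), input_length)
# with chatbot_encoder_hidden_size = 128 and a bidirectional GRU:
print(output.size())  # [batch_size, seq_len, 256]
print(hidden.size())  # [1, batch_size, 256]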

decoder.py

import torch.nn as nn
import torch
import config
import torch.nn.functional as F
import random
import numpy as np


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.vocab_size = len(config.target_ws)
        self.hidden_size = config.chatbot_encoder_hidden_size * 2
        self.embedding = nn.Embedding(
            num_embeddings = len(config.target_ws),
            embedding_dim = config.chatbot_decoder_embedding_dim
        )
        self.gru = nn.GRU(
            input_size = config.chatbot_decoder_embedding_dim,
            # the encoder's bidirectional hidden states were concatenated, so the
            # decoder's hidden size is twice the encoder's; "hidden_size" in the
            # comments below always means this decoder hidden size
            hidden_size = self.hidden_size,
            num_layers = config.chatbot_decoder_numlayers,
            batch_first = True,
            bidirectional = False
        )
        self.fc = nn.Linear(self.hidden_size, self.vocab_size)


    def forward(self, encoder_hidden, target):
        """
        :param encoder_hidden: [1, batch_size, hidden_size]
        :param target: [batch_size, seq_len]
        :return:
        """
        batch_size = encoder_hidden.size(1)
        # a [batch_size, 1] tensor of SOS as the decoder's first time-step input
        decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)
        # encoder_hidden [1, batch_size, hidden_size] is the decoder's initial hidden state
        decoder_hidden = encoder_hidden
        # a [batch_size, seq_len, vocab_size] buffer collecting every time step's output
        decoder_outputs = torch.zeros([batch_size, config.chatbot_target_max_len, self.vocab_size]).to(config.device)
        # decide whether to use teacher forcing for this batch
        if random.random() > config.teacher_forcing:
            # iterate over the time steps
            for t in range(config.chatbot_target_max_len):
                output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
                # store the result of time step t
                decoder_outputs[:, t, :] = output_t
                # without teacher forcing, feed the prediction back in as the next input
                # index [batch_size, 1]
                """
                with max():
                value, index = output_t.max(dim = -1) # [batch_size]
                decoder_input = index.unsqueeze(1)
                with argmax():
                index = output_t.argmax(dim = -1) # [batch_size]
                decoder_input = index.unsqueeze(1)
                """
                value, index = torch.topk(output_t, k = 1)
                # decoder_input must keep the shape [batch_size, 1]
                decoder_input = index
        else:
            for t in range(config.chatbot_target_max_len):
                output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
                decoder_outputs[:, t, :] = output_t
                # with teacher forcing, the ground truth is the next input,
                # shaped [batch_size, 1]
                decoder_input = target[:, t].unsqueeze(1)
        return decoder_outputs, decoder_hidden


    def forward_step(self, decoder_input, decoder_hidden):
        """
        One decoding time step.
        :param decoder_input: [batch_size, 1]
        :param decoder_hidden: [1, batch_size, hidden_size]
        :return:
        """
        # [batch_size, 1] -> [batch_size, 1, embedding_dim]
        decoder_input_embeded = self.embedding(decoder_input)
        # decoder_output_t [batch_size, 1, embedding_dim] -> [batch_size, 1, hidden_size]
        # decoder_hidden_t [1, batch_size, hidden_size]
        decoder_output_t, decoder_hidden_t = self.gru(decoder_input_embeded, decoder_hidden)
        # squeeze to [batch_size, hidden_size] before the fully connected layer
        decoder_output_t = decoder_output_t.squeeze(1)
        # fc -> [batch_size, vocab_size]
        decoder_output_t = F.log_softmax(self.fc(decoder_output_t), dim = -1)
        return decoder_output_t, decoder_hidden_t


    def evaluate(self, encoder_hidden):
        """
        Evaluation; same logic as forward, but always greedy.
        :param encoder_hidden: the encoder's final hidden state [1, batch_size, hidden_size]
        :return:
        """
        batch_size = encoder_hidden.size(1)
        # a [batch_size, 1] tensor of SOS as the decoder's first time-step input
        decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)
        # encoder_hidden is the decoder's initial hidden state [1, batch_size, hidden_size]
        decoder_hidden = encoder_hidden
        # a [batch_size, seq_len, vocab_size] buffer collecting every time step's output
        decoder_outputs = torch.zeros((batch_size, config.chatbot_target_max_len, self.vocab_size)).to(config.device)
        # a list collecting the predicted indices of every time step
        predict_result = []
        for t in range(config.chatbot_target_max_len):
            decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            # store this time step; decoder_output_t [batch_size, vocab_size]
            decoder_outputs[:, t, :] = decoder_output_t
            # evaluation is greedy: take the argmax at every step
            index = torch.argmax(decoder_output_t, dim = -1)
            # the prediction becomes the next time step's input
            decoder_input = index.unsqueeze(1)
            # store this time step's predictions
            predict_result.append(index.cpu().detach().numpy())  # [[batch], [batch], ...] -> [seq_len, batch_size]
        # transpose so that every row is one sample's predicted index sequence
        predict_result = np.array(predict_result).transpose()  # (batch_size, seq_len) ndarray
        return decoder_outputs, predict_result

seq2seq.py

import torch.nn as nn
from encoder import Encoder
from decoder import Decoder


class Seq2SeqModel(nn.Module):
    def __init__(self):
        super(Seq2SeqModel, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()


    def forward(self, input, target, input_length, target_length):
        encoder_outputs, encoder_hidden = self.encoder(input, input_length)
        decoder_outputs, decoder_hidden = self.decoder(encoder_hidden, target)
        return decoder_outputs


    def evaluation(self, input, input_length):
        encoder_outputs, encoder_hidden = self.encoder(input, input_length)
        decoder_outputs, predict_result = self.decoder.evaluate(encoder_hidden)
        return decoder_outputs, predict_result
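
A quick smoke test of the untrained model (a sketch; it assumes the split corpus files and pickled vocabularies from the earlier steps exist):

# smoke test for the full model (sketch)
import config
from dataset import get_dataloader
from seq2seq import Seq2SeqModel

model = Seq2SeqModel().to(config.device)
input, target, input_length, target_length = next(iter(get_dataloader(train = True)))
outputs = model(input.to(config.device), target.to(config.device), input_length, target_length)
print(outputs.size())  # [batch_size, chatbot_target_max_len, len(config.target_ws)]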

Training and evaluation

train.py

import torch
from seq2seq import Seq2SeqModel
import torch.optim as optim
import config
from dataset import get_dataloader
from tqdm import tqdm
import torch.nn.functional as F
import numpy as np
import os


model = Seq2SeqModel().to(config.device)
if os.path.exists('./model/chatbot_model.pkl'):
    model.load_state_dict(torch.load('./model/chatbot_model.pkl'))

optimizer = optim.Adam(model.parameters())
loss_list = []


def train(epoch):
    train_dataloader = get_dataloader(train = True)
    bar = tqdm(train_dataloader, desc = 'training', total = len(train_dataloader))
    model.train()
    for idx, (input, target, input_length, target_length) in enumerate(bar):
        input = input.to(config.device)
        target = target.to(config.device)
        input_length = input_length.to(config.device)
        target_length = target_length.to(config.device)
        optimizer.zero_grad()
        decoder_outputs = model(input, target, input_length, target_length)
        # the decoder already applies log_softmax, so the loss is F.nll_loss
        # decoder_outputs [batch_size, seq_len, vocab_size]
        # target [batch_size, seq_len]
        loss = F.nll_loss(decoder_outputs.view(-1, len(config.target_ws)), target.view(-1), ignore_index = config.target_ws.PAD)
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
        bar.set_description('epoch:{}, idx:{}/{}, loss:{:.6f}'.format(epoch + 1, idx, len(train_dataloader), np.mean(loss_list)))
        if idx % 100 == 0:
            torch.save(model.state_dict(), './model/chatbot_model.pkl')


if __name__ == '__main__':
    for i in range(100):
        train(i)

eval.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataset import get_dataloader
import config
import numpy as np
from seq2seq import Seq2SeqModel
import os
from tqdm import tqdm



model = Seq2SeqModel().to(config.device)
if os.path.exists('./model/chatbot_model.pkl'):
    model.load_state_dict(torch.load('./model/chatbot_model.pkl'))


def eval():
    model.eval()
    loss_list = []
    test_data_loader = get_dataloader(train = False)
    with torch.no_grad():
        bar = tqdm(test_data_loader, desc = 'testing', total = len(test_data_loader))
        for idx, (input, target, input_length, target_length) in enumerate(bar):
            input = input.to(config.device)
            target = target.to(config.device)
            input_length = input_length.to(config.device)
            target_length = target_length.to(config.device)
            # run greedy evaluation
            decoder_outputs, predict_result = model.evaluation(input, input_length)
            # compute the loss
            loss = F.nll_loss(decoder_outputs.view(-1, len(config.target_ws)), target.view(-1), ignore_index = config.target_ws.PAD)
            loss_list.append(loss.item())
            bar.set_description('idx:{}/{}, loss:{:.6f}'.format(idx, len(test_data_loader), np.mean(loss_list)))


if __name__ == '__main__':
    eval()

Chatting with the model: interface.py

from cut_sentence import cut
import torch
import config
from seq2seq import Seq2SeqModel
import os


# simulate a chat: answer whatever the user types in
def interface():
    # load the trained model
    model = Seq2SeqModel().to(config.device)
    assert os.path.exists('./model/chatbot_model.pkl'), 'please train the model first!'
    model.load_state_dict(torch.load('./model/chatbot_model.pkl'))
    model.eval()

    while True:
        # tokenize the raw user input
        input_string = input('me>>:')
        if input_string == 'q':
            print('bye')
            break
        input_cuted = cut(input_string, by_word = True)
        # convert to an index sequence and wrap in tensors
        input_tensor = torch.LongTensor([config.input_ws.transform(input_cuted, max_len = config.chatbot_input_max_len)]).to(config.device)
        # clamp the length to max_len, otherwise pack_padded_sequence fails on long inputs
        input_length_tensor = torch.LongTensor([min(len(input_cuted), config.chatbot_input_max_len)]).to(config.device)
        # predict
        outputs, predict = model.evaluation(input_tensor, input_length_tensor)
        # convert the index sequence back to text
        result = config.target_ws.inverse_transform(predict[0])
        print('chatbot>>:', result)


if __name__ == '__main__':
    interface()

Adding attention to the decoder

attention.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import config


class Attention(nn.Module):
    def __init__(self, method):
        """
        Attention mechanism.
        :param method: one of three ways to compute attention weights: general, dot, concat
        """
        super(Attention, self).__init__()
        self.method = method
        self.hidden_size = config.chatbot_encoder_hidden_size
        assert self.method in ['dot', 'general', 'concat'], 'attention method error'
        if self.method == 'dot':
            # dot: multiply decoder_hidden and encoder_outputs directly
            pass
        elif self.method == 'general':
            # general: linearly transform decoder_hidden before multiplying with encoder_outputs
            self.Wa = nn.Linear(config.chatbot_encoder_hidden_size * 2, config.chatbot_encoder_hidden_size * 2,
                                bias=False)
        elif self.method == 'concat':
            self.Wa = nn.Linear(config.chatbot_encoder_hidden_size * 4, config.chatbot_encoder_hidden_size * 2,
                                bias=False)
            self.Va = nn.Linear(config.chatbot_encoder_hidden_size * 2, 1, bias = False)


    def forward(self, decoder_hidden, encoder_outputs):
        """
        进行三种运算得到attn_weights
        :param decoder_hidden: decoder每个时间步的隐藏状态[1, batch_size, en_hidden_size * 2]
        由于encoder中使用Bi-GRU,最后对双向hidden进行了拼接,因此de_hidden_size = en_hidden_size * 2
        未拼接前 encoder_hidden [1, batch_size, en_hidden_size]
        :param encoder_outputs:encoder最后的输出[batch_size, en_seq_len, en_hidden_size * 2]
        :return:
        """
        if self.method == 'dot':
            return self.dot_score(decoder_hidden, encoder_outputs)
        elif self.method == 'general':
            return self.general_score(decoder_hidden, encoder_outputs)
        elif self.method == 'concat':
            return self.concat_score(decoder_hidden, encoder_outputs)


    def dot_score(self, decoder_hidden, encoder_outputs):
        """
        dot: a batched matrix product of encoder_outputs and decoder_hidden
        :param decoder_hidden: [1, batch_size, en_hidden_size * 2]
        :param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
        :return:
        """
        # reshape decoder_hidden to [batch_size, en_hidden_size * 2, 1] for the product
        # bmm result: [batch_size, en_seq_len, 1]
        # squeeze away the trailing 1: [batch_size, en_seq_len]
        # finally normalize over en_seq_len with softmax, so the weights form a
        # proper distribution over the encoder positions (they are later used to
        # average encoder_outputs into a context vector)
        return F.softmax(torch.bmm(encoder_outputs, decoder_hidden.permute(1, 2, 0)).squeeze(-1), dim = -1)


    def general_score(self, decoder_hidden, encoder_outputs):
        """
        general: linearly transform decoder_hidden, then multiply with encoder_outputs
        :param decoder_hidden: [1, batch_size, en_hidden_size * 2]
        :param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
        :return:
        """
        # the linear layer expects a 2D tensor, so its input dimension is en_hidden_size * 2
        # [1, batch_size, en_hidden_size * 2] -> [batch_size, en_hidden_size * 2]
        decoder_hidden = decoder_hidden.squeeze(0)
        # reshape to [batch_size, en_hidden_size * 2, 1] as in the dot method,
        # so the linear layer's output dimension is also en_hidden_size * 2
        decoder_hidden = self.Wa(decoder_hidden).unsqueeze(-1)
        # bmm: [batch_size, en_seq_len, 1] -> squeeze: [batch_size, en_seq_len]
        # mind the operand order of torch.bmm: the shapes must be compatible
        return F.softmax(torch.bmm(encoder_outputs, decoder_hidden).squeeze(-1), dim = -1)


    def concat_score(self, decoder_hidden, encoder_outputs):
        """
        concat: concatenate decoder_hidden with encoder_outputs, pass the result
        through a linear layer and tanh, then project down to one score per
        encoder position
        :param decoder_hidden: [1, batch_size, en_hidden_size * 2]
        :param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
        :return:
        """
        encoder_seq_len = encoder_outputs.size(1)
        batch_size = encoder_outputs.size(0)
        # repeat(n, 1, 1) repeats a tensor n times along dim 0 and once along the others
        # decoder_hidden [1, batch_size, en_hidden_size * 2] -> squeeze(0): [batch_size, en_hidden_size * 2]
        # -> repeat: [en_seq_len, batch_size, en_hidden_size * 2] -> transpose: [batch_size, en_seq_len, en_hidden_size * 2]
        decoder_hidden_repeated = decoder_hidden.squeeze(0).repeat(encoder_seq_len, 1, 1).transpose(1, 0)
        # concatenate decoder_hidden_repeated with encoder_outputs
        # cat: [batch_size, en_seq_len, en_hidden_size * 4]
        # view: [batch_size * en_seq_len, en_hidden_size * 4]
        # hence the first linear layer's input dimension is en_hidden_size * 4
        h_cated = torch.cat((decoder_hidden_repeated, encoder_outputs), dim = -1).view(batch_size * encoder_seq_len, -1)
        # linear -> tanh -> second linear, ending up as [batch_size, en_seq_len]
        # h_cated -> Wa: [batch_size * en_seq_len, en_hidden_size * 4] -> [batch_size * en_seq_len, en_hidden_size * 2]
        # -> Va: [batch_size * en_seq_len, en_hidden_size * 2] -> [batch_size * en_seq_len, 1]
        # -> view: [batch_size, en_seq_len]
        attn_weight = self.Va(torch.tanh(self.Wa(h_cated))).view([batch_size, encoder_seq_len])
        return F.softmax(attn_weight, dim = -1)
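
A small shape check with random tensors (a sketch; importing attention.py pulls in config, so the pickled vocabularies must already exist):

# shape check for the three attention variants (sketch)
import torch
from attention import Attention

batch_size, en_seq_len, en_hidden = 4, 10, 128  # en_hidden matches chatbot_encoder_hidden_size
decoder_hidden = torch.randn(1, batch_size, en_hidden * 2)
encoder_outputs = torch.randn(batch_size, en_seq_len, en_hidden * 2)
for method in ['dot', 'general', 'concat']:
    attn = Attention(method)
    weight = attn(decoder_hidden, encoder_outputs)
    print(method, weight.size())  # [4, 10]
    print(weight.sum(dim = -1))   # every row sums to 1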

decoder_attn.py (the modified decoder.py)

import torch
import torch.nn as nn
import config
from attention import Attention
import random
import torch.nn.functional as F
import numpy as np


class Decoder_Attn(nn.Module):
    def __init__(self):
        super(Decoder_Attn, self).__init__()
        self.vocab_size = len(config.target_ws)
        # the encoder is a Bi-GRU whose hidden states are concatenated; to keep
        # the attention arithmetic readable, hidden_size here is the raw encoder
        # hidden size (en_hidden_size)
        self.hidden_size = config.chatbot_encoder_hidden_size
        self.embedding = nn.Embedding(
            num_embeddings = self.vocab_size,
            embedding_dim = config.chatbot_decoder_embedding_dim
        )
        self.gru = nn.GRU(
            input_size = config.chatbot_decoder_embedding_dim,
            hidden_size = self.hidden_size * 2,
            num_layers = config.chatbot_decoder_numlayers,
            batch_first = True,
            bidirectional = False
        )
        # maps each time step's attentional output to vocabulary scores
        self.fc = nn.Linear(self.hidden_size * 2, self.vocab_size)
        # attention-weight module
        self.attn = Attention(method = 'general')
        # maps the concatenation of decoder output and context vector,
        # [*, en_hidden_size * 4], back down to [*, en_hidden_size * 2]
        self.fc_attn = nn.Linear(self.hidden_size * 4, self.hidden_size * 2)


    def forward(self, encoder_hidden, target, encoder_outputs):
        """
        :param encoder_hidden: [1, batch_size, en_hidden_size * 2]
        :param target: [batch_size, en_seq_len]
        :param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2];
        new parameter: with attention, the decoder combines encoder_outputs
        with decoder_hidden to compute the attention representation
        :return:
        """
        batch_size = encoder_hidden.size(1)
        # a [batch_size, 1] tensor of SOS as the decoder's first time-step input
        decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)
        # a [batch_size, de_seq_len, vocab_size] buffer collecting every time step's output
        decoder_outputs = torch.zeros([batch_size, config.chatbot_target_max_len, self.vocab_size]).to(config.device)
        # encoder_hidden is the decoder's initial hidden state
        decoder_hidden = encoder_hidden
        # decide whether to use teacher forcing for this batch
        teacher_forcing = random.random() > 0.5
        if teacher_forcing:
            # iterate over the time steps
            for t in range(config.chatbot_target_max_len):
                # decoder_output_t [batch_size, vocab_size]
                # decoder_hidden [1, batch_size, en_hidden_size * 2]
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                # store the result of time step t
                decoder_outputs[:, t, :] = decoder_output_t
                # with teacher forcing, the ground truth is the next input
                # target[:, t] [batch_size]
                # decoder_input [batch_size, 1]
                decoder_input = target[:, t].unsqueeze(1)
        else:
            for t in range(config.chatbot_target_max_len):
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                decoder_outputs[:, t, :] = decoder_output_t
                # without teacher forcing, the prediction is the next input
                # index [batch_size]
                # decoder_input [batch_size, 1]
                index = torch.argmax(decoder_output_t, dim = -1)
                decoder_input = index.unsqueeze(-1)
        return decoder_outputs, decoder_hidden


    def forward_step(self, decoder_input, decoder_hidden, encoder_outputs):
        """
        One decoding time step.
        :param decoder_input: [batch_size, 1]
        :param decoder_hidden: [1, batch_size, en_hidden_size * 2]
        :param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
        :return:
        """
        # embedding -> gru -> attention -> fc, returning log_softmax scores
        # decoder_input_embeded [batch_size, 1] -> [batch_size, 1, embedding_dim]
        decoder_input_embeded = self.embedding(decoder_input)
        decoder_output, decoder_hidden = self.gru(decoder_input_embeded, decoder_hidden)
        """attention computed from decoder_hidden and encoder_outputs"""
        # 1. compute the attention weights
        # decoder_hidden [1, batch_size, en_hidden_size * 2]
        # encoder_outputs [batch_size, en_seq_len, en_hidden_size * 2]
        # attn_weight [batch_size, en_seq_len]
        attn_weight = self.attn(decoder_hidden, encoder_outputs)
        # 2. combine attn_weight and encoder_outputs into the context vector
        # encoder_outputs [batch_size, en_seq_len, en_hidden_size * 2]
        # attn_weight [batch_size, en_seq_len]
        # unsqueeze attn_weight to [batch_size, 1, en_seq_len] for the bmm
        # context_vector [batch_size, 1, en_hidden_size * 2]
        context_vector = torch.bmm(attn_weight.unsqueeze(1), encoder_outputs)
        # 3. concatenate context_vector with the decoder output and project,
        # giving this time step's attentional output
        # decoder_output [batch_size, 1, en_hidden_size * 2]
        # context_vector [batch_size, 1, en_hidden_size * 2]
        # cat -> [batch_size, 1, en_hidden_size * 4]
        # squeeze(1) before the fully connected layer -> [batch_size, en_hidden_size * 4]
        # fc_attn maps (en_hidden_size * 4) -> (en_hidden_size * 2)
        # fc maps (en_hidden_size * 2) -> vocab_size
        # note: use torch.tanh; F.tanh warns 'nn.functional.tanh is deprecated. Use torch.tanh instead'
        attn_result = torch.tanh(self.fc_attn(torch.cat((decoder_output, context_vector), dim = -1).squeeze(1)))
        # attn_result [batch_size, en_hidden_size * 2]
        # decoder_output_t [batch_size, en_hidden_size * 2] -> [batch_size, vocab_size]
        decoder_output_t = F.log_softmax(self.fc(attn_result), dim = -1)
        # decoder_hidden [1, batch_size, en_hidden_size * 2]
        return decoder_output_t, decoder_hidden


    def evaluate(self, encoder_hidden, encoder_outputs):
        """
        Greedy evaluation.
        :param encoder_hidden: [1, batch_size, en_hidden_size * 2]
        :param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
        :return:
        """
        batch_size = encoder_hidden.size(1)
        # initialize the decoder's first input, output buffer and hidden state
        decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)
        decoder_outputs = torch.zeros((batch_size, config.chatbot_target_max_len, len(config.target_ws))).to(config.device)
        decoder_hidden = encoder_hidden
        # a list collecting the predicted indices
        predict_result = []
        for t in range(config.chatbot_target_max_len):
            decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
            decoder_outputs[:, t, :] = decoder_output_t
            index = torch.argmax(decoder_output_t, dim = -1)
            decoder_input = index.unsqueeze(-1)
            predict_result.append(index.cpu().detach().numpy())
        predict_result = np.array(predict_result).transpose()
        return decoder_outputs, predict_result

seq2seq_attn.py (the modified seq2seq.py)

from encoder import Encoder
from decoder_attn import Decoder_Attn
import torch.nn as nn


class Seq2Seq_Attn(nn.Module):
    def __init__(self):
        super(Seq2Seq_Attn, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder_Attn()


    def forward(self, input, target, input_length, target_length):
        encoder_output, encoder_hidden = self.encoder(input, input_length)
        decoder_output, decoder_hidden = self.decoder(encoder_hidden, target, encoder_output)
        return decoder_output


    def evaluation_attn(self, input, input_length):
        encoder_output, encoder_hidden = self.encoder(input, input_length)
        decoder_output, predict_result = self.decoder.evaluate(encoder_hidden, encoder_output)
        return decoder_output, predict_result

train.py

from seq2seq_attn import Seq2Seq_Attn
import torch
import torch.optim as optim
import os
from tqdm import tqdm
from dataset import get_dataloader
import config
import torch.nn.functional as F


model = Seq2Seq_Attn().to(config.device)
if os.path.exists('./model/chatbot_attn_model.pkl'):
    model.load_state_dict(torch.load('./model/chatbot_attn_model.pkl'))
optimizer = optim.Adam(model.parameters())
loss_list = []

def train_attn(epoch):
    model.train()
    train_dataloader = get_dataloader(train = True)
    bar = tqdm(train_dataloader, desc = 'attn_training...', total = len(train_dataloader))
    for idx, (input, target, input_length, target_length) in enumerate(bar):
        input = input.to(config.device)
        target = target.to(config.device)
        input_length = input_length.to(config.device)
        target_length = target_length.to(config.device)
        optimizer.zero_grad()
        outputs = model(input, target, input_length, target_length)
        # outputs [batch_size, de_seq_len, vocab_size]
        # target [batch_size, de_seq_len]
        loss = F.nll_loss(outputs.view(-1, len(config.target_ws)), target.view(-1), ignore_index = config.target_ws.PAD)
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
        bar.set_description('epoch:{}, idx:{}/{}, loss:{:.6f}'.format(epoch + 1, idx, len(train_dataloader), loss.item()))
        if idx % 100 == 0:
            torch.save(model.state_dict(), './model/chatbot_attn_model.pkl')


if __name__ == '__main__':
    for i in range(10):
        train_attn(i)

eval.py

from seq2seq_attn import Seq2Seq_Attn
import torch
import config
import numpy as np
from dataset import get_dataloader
import torch.nn.functional as F
from tqdm import tqdm


model = Seq2Seq_Attn().to(config.device)
model.load_state_dict(torch.load('./model/chatbot_attn_model.pkl'))
loss_list = []


def eval():
    model.eval()
    test_dataloader = get_dataloader(train = False)
    bar = tqdm(test_dataloader, desc = 'attn_test...', total = len(test_dataloader))
    for idx, (input, target, input_length, target_length) in enumerate(bar):
        input = input.to(config.device)
        target = target.to(config.device)
        input_length = input_length.to(config.device)
        target_length = target_length.to(config.device)
        with torch.no_grad():
            output, predict_result = model.evaluation_attn(input, input_length)
            loss = F.nll_loss(output.view(-1, len(config.target_ws)), target.view(-1), ignore_index = config.target_ws.PAD)
            loss_list.append(loss.item())
            bar.set_description('idx:{}/{}, loss:{}'.format(idx, len(test_dataloader), np.mean(loss_list)))


if __name__ == '__main__':
    eval()

interface.py

from cut_sentence import cut
import torch
import config
from seq2seq_attn import Seq2Seq_Attn
import os


# simulate a chat: answer whatever the user types in
def interface():
    # load the trained model
    model = Seq2Seq_Attn().to(config.device)
    assert os.path.exists('./model/chatbot_attn_model.pkl'), 'please train the model first!'
    model.load_state_dict(torch.load('./model/chatbot_attn_model.pkl'))
    model.eval()

    while True:
        # tokenize the raw user input
        input_string = input('me>>:')
        if input_string == 'q':
            print('bye')
            break
        input_cuted = cut(input_string, by_word = True)
        # convert to an index sequence and wrap in tensors
        input_tensor = torch.LongTensor([config.input_ws.transform(input_cuted, max_len = config.chatbot_input_max_len)]).to(config.device)
        # clamp the length to max_len, otherwise pack_padded_sequence fails on long inputs
        input_length_tensor = torch.LongTensor([min(len(input_cuted), config.chatbot_input_max_len)]).to(config.device)
        # predict
        outputs, predict = model.evaluation_attn(input_tensor, input_length_tensor)
        # convert the index sequence back to text
        result = config.target_ws.inverse_transform(predict[0])
        print('chatbot>>:', result)


if __name__ == '__main__':
    interface()

Beam search
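
Greedy decoding keeps only the single most probable token at every step, so one early mistake can never be recovered. Beam search instead tracks the beam_width most probable partial sequences, scoring each by the sum of its token log-probabilities, and at every step expands each unfinished candidate with its beam_width best continuations before pruning back down to beam_width. A small min-heap is a convenient container for this, which is what the Beam class below implements.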

decoder_beam.py (the decoder with beam search)

import torch
import torch.nn as nn
import config
from attention import Attention
import random
import torch.nn.functional as F
import numpy as np
import heapq


class Decoder_Attn_Beam(nn.Module):
    def __init__(self):
        super(Decoder_Attn_Beam, self).__init__()
        self.vocab_size = len(config.target_ws)
        # the encoder is a Bi-GRU whose hidden states are concatenated; to keep
        # the attention arithmetic readable, hidden_size here is the raw encoder
        # hidden size (en_hidden_size)
        self.hidden_size = config.chatbot_encoder_hidden_size
        self.embedding = nn.Embedding(
            num_embeddings = self.vocab_size,
            embedding_dim = config.chatbot_decoder_embedding_dim
        )
        self.gru = nn.GRU(
            input_size = config.chatbot_decoder_embedding_dim,
            hidden_size = self.hidden_size * 2,
            num_layers = config.chatbot_decoder_numlayers,
            batch_first = True,
            bidirectional = False
        )
        # maps each time step's attentional output to vocabulary scores
        self.fc = nn.Linear(self.hidden_size * 2, self.vocab_size)
        # attention-weight module
        self.attn = Attention(method = 'general')
        # maps the concatenation of decoder output and context vector,
        # [*, en_hidden_size * 4], back down to [*, en_hidden_size * 2]
        self.fc_attn = nn.Linear(self.hidden_size * 4, self.hidden_size * 2)


    def forward(self, encoder_hidden, target, encoder_outputs):
        """
        :param encoder_hidden: [1, batch_size, en_hidden_size * 2]
        :param target: [batch_size, en_seq_len]
        :param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2];
        new parameter: with attention, the decoder combines encoder_outputs
        with decoder_hidden to compute the attention representation
        :return:
        """
        batch_size = encoder_hidden.size(1)
        # a [batch_size, 1] tensor of SOS as the decoder's first time-step input
        decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)
        # a [batch_size, de_seq_len, vocab_size] buffer collecting every time step's output
        decoder_outputs = torch.zeros([batch_size, config.chatbot_target_max_len, self.vocab_size]).to(config.device)
        # encoder_hidden is the decoder's initial hidden state
        decoder_hidden = encoder_hidden
        # decide whether to use teacher forcing for this batch
        teacher_forcing = random.random() > 0.5
        if teacher_forcing:
            # iterate over the time steps
            for t in range(config.chatbot_target_max_len):
                # decoder_output_t [batch_size, vocab_size]
                # decoder_hidden [1, batch_size, en_hidden_size * 2]
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                # store the result of time step t
                decoder_outputs[:, t, :] = decoder_output_t
                # with teacher forcing, the ground truth is the next input
                # target[:, t] [batch_size]
                # decoder_input [batch_size, 1]
                decoder_input = target[:, t].unsqueeze(1)
        else:
            for t in range(config.chatbot_target_max_len):
                decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
                decoder_outputs[:, t, :] = decoder_output_t
                # without teacher forcing, the prediction is the next input
                # index [batch_size]
                # decoder_input [batch_size, 1]
                index = torch.argmax(decoder_output_t, dim = -1)
                decoder_input = index.unsqueeze(-1)
        return decoder_outputs, decoder_hidden


    def forward_step(self, decoder_input, decoder_hidden, encoder_outputs):
        """
        One decoding time step.
        :param decoder_input: [batch_size, 1]
        :param decoder_hidden: [1, batch_size, en_hidden_size * 2]
        :param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
        :return:
        """
        # embedding -> gru -> attention -> fc, returning log_softmax scores
        # decoder_input_embeded [batch_size, 1] -> [batch_size, 1, embedding_dim]
        decoder_input_embeded = self.embedding(decoder_input)
        decoder_output, decoder_hidden = self.gru(decoder_input_embeded, decoder_hidden)
        """attention computed from decoder_hidden and encoder_outputs"""
        # 1. compute the attention weights
        # decoder_hidden [1, batch_size, en_hidden_size * 2]
        # encoder_outputs [batch_size, en_seq_len, en_hidden_size * 2]
        # attn_weight [batch_size, en_seq_len]
        attn_weight = self.attn(decoder_hidden, encoder_outputs)
        # 2. combine attn_weight and encoder_outputs into the context vector
        # encoder_outputs [batch_size, en_seq_len, en_hidden_size * 2]
        # attn_weight [batch_size, en_seq_len]
        # unsqueeze attn_weight to [batch_size, 1, en_seq_len] for the bmm
        # context_vector [batch_size, 1, en_hidden_size * 2]
        context_vector = torch.bmm(attn_weight.unsqueeze(1), encoder_outputs)
        # 3. concatenate context_vector with the decoder output and project,
        # giving this time step's attentional output
        # decoder_output [batch_size, 1, en_hidden_size * 2]
        # context_vector [batch_size, 1, en_hidden_size * 2]
        # cat -> [batch_size, 1, en_hidden_size * 4]
        # squeeze(1) before the fully connected layer -> [batch_size, en_hidden_size * 4]
        # fc_attn maps (en_hidden_size * 4) -> (en_hidden_size * 2)
        # fc maps (en_hidden_size * 2) -> vocab_size
        # note: use torch.tanh; F.tanh warns 'nn.functional.tanh is deprecated. Use torch.tanh instead'
        attn_result = torch.tanh(self.fc_attn(torch.cat((decoder_output, context_vector), dim = -1).squeeze(1)))
        # attn_result [batch_size, en_hidden_size * 2]
        # decoder_output_t [batch_size, en_hidden_size * 2] -> [batch_size, vocab_size]
        decoder_output_t = F.log_softmax(self.fc(attn_result), dim = -1)
        # decoder_hidden [1, batch_size, en_hidden_size * 2]
        return decoder_output_t, decoder_hidden


    def evaluate(self, encoder_hidden, encoder_outputs):
        """
        Greedy evaluation.
        :param encoder_hidden: [1, batch_size, en_hidden_size * 2]
        :param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
        :return:
        """
        batch_size = encoder_hidden.size(1)
        # initialize the decoder's first input, output buffer and hidden state
        decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)
        decoder_outputs = torch.zeros((batch_size, config.chatbot_target_max_len, len(config.target_ws))).to(config.device)
        decoder_hidden = encoder_hidden
        # a list collecting the predicted indices
        predict_result = []
        for t in range(config.chatbot_target_max_len):
            decoder_output_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)
            decoder_outputs[:, t, :] = decoder_output_t
            index = torch.argmax(decoder_output_t, dim = -1)
            decoder_input = index.unsqueeze(-1)
            predict_result.append(index.cpu().detach().numpy())
        predict_result = np.array(predict_result).transpose()
        return decoder_outputs, predict_result


    def evaluate_beam(self, encoder_hidden, encoder_outputs):
        """
        Evaluation with beam search.
        :param encoder_hidden: [1, batch_size, en_hidden_size * 2]
        :param encoder_outputs: [batch_size, en_seq_len, en_hidden_size * 2]
        :return:
        """
        # beam search as written here only supports batch_size == 1
        # (batch_size is dim 1 of encoder_hidden, not dim 0)
        batch_size = encoder_hidden.size(1)
        assert batch_size == 1, 'batch_size must be 1'
        # initialize the decoder's first input and hidden state
        decoder_input = torch.LongTensor([[config.target_ws.SOS]] * batch_size).to(config.device)
        decoder_hidden = encoder_hidden

        # the initial beam
        prev_beam = Beam()
        # seed the heap; scores are sums of log-probabilities, so start at 0
        prev_beam.add(0, False, [decoder_input], decoder_input, decoder_hidden)
        # keep expanding the previous beam into the current one
        while True:
            cur_beam = Beam()
            # for every candidate: if it already ended with EOS, carry it over
            # unchanged; otherwise expand it by one forward_step
            for _prob, _complete, _seq_list, _decoder_input, _decoder_hidden in prev_beam:
                if _complete:
                    cur_beam.add(_prob, _complete, _seq_list, _decoder_input, _decoder_hidden)
                else:
                    # decoder_output_t [1, vocab_size]
                    # decoder_hidden [1, 1, en_hidden_size * 2]
                    decoder_output_t, decoder_hidden = self.forward_step(_decoder_input, _decoder_hidden, encoder_outputs)
                    value, index = torch.topk(decoder_output_t, k = config.beam_width)
                    for val, idx in zip(value[0], index[0]):
                        # forward_step returns log_softmax scores, so the
                        # log-probabilities are summed, not multiplied
                        cur_prob = _prob + val.item()
                        decoder_input = torch.LongTensor([[idx.item()]]).to(config.device)
                        cur_seq_list = _seq_list + [decoder_input]
                        cur_complete = idx.item() == config.target_ws.EOS
                        cur_beam.add(cur_prob, cur_complete, cur_seq_list, decoder_input, decoder_hidden)
            # take the highest-scoring candidate in the new beam; stop once it
            # ends with EOS or has reached the maximum length
            best_prob, best_complete, best_seq_list, _, _ = max(cur_beam, key = lambda x: x[0])
            if best_complete or len(best_seq_list) - 1 == config.chatbot_target_max_len:
                # strip SOS/EOS so the sequence can be turned back into text
                best_seq_list = [i.item() for i in best_seq_list]
                if best_seq_list[0] == config.target_ws.SOS:
                    best_seq_list = best_seq_list[1:]
                if best_seq_list[-1] == config.target_ws.EOS:
                    best_seq_list = best_seq_list[:-1]
                return best_seq_list
            else:
                # otherwise expand the new beam in the next iteration
                prev_beam = cur_beam


class Beam(object):
    # a fixed-size heap implementing the beam
    def __init__(self):
        self.heapq = list()  # the underlying heap storage
        self.beam_width = config.beam_width  # keep at most beam_width candidates
        self.counter = 0  # tie-breaker so heapq never compares the tensors


    def add(self, prob, complete, seq_list, decoder_input, decoder_hidden):
        """
        Push one candidate; pop the worst one if the beam is over capacity.
        :param prob: sum of token log-probabilities
        :param complete: whether the sequence ends with EOS
        :param seq_list: the list of all tokens so far
        :param decoder_input: the next decoding input, from the previous step
        :param decoder_hidden: the next decoding hidden state, from the previous step
        :return:
        """
        heapq.heappush(self.heapq, [prob, self.counter, complete, seq_list, decoder_input, decoder_hidden])
        self.counter += 1
        # keep only the beam_width best candidates
        if len(self.heapq) > self.beam_width:
            heapq.heappop(self.heapq)


    def __iter__(self):
        for prob, _, complete, seq_list, decoder_input, decoder_hidden in self.heapq:
            yield prob, complete, seq_list, decoder_input, decoder_hidden
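
Because heapq is a min-heap, heappop always removes the lowest-scoring candidate, so after every add the heap holds exactly the beam_width best partial sequences. The counter field only breaks score ties: without it, heapq would fall through to comparing the stored lists of tensors, which is fragile at best and can raise a RuntimeError.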

seq2seq_beam.py

from encoder import Encoder
from decoder_beam import Decoder_Attn_Beam
import torch.nn as nn


class Seq2Seq_Attn_Beam(nn.Module):
    def __init__(self):
        super(Seq2Seq_Attn_Beam, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder_Attn_Beam()


    def forward(self, input, target, input_length, target_length):
        encoder_output, encoder_hidden = self.encoder(input, input_length)
        decoder_output, decoder_hidden = self.decoder(encoder_hidden, target, encoder_output)
        return decoder_output


    def evaluation_attn(self, input, input_length):
        encoder_output, encoder_hidden = self.encoder(input, input_length)
        decoder_output, predict_result = self.decoder.evaluate(encoder_hidden, encoder_output)
        return decoder_output, predict_result


    def evaluation_beam(self, input, input_length):
        encoder_output, encoder_hidden = self.encoder(input, input_length)
        best_seq = self.decoder.evaluate_beam(encoder_hidden, encoder_output)
        return best_seq

interface.py

from cut_sentence import cut
import torch
import config
from seq2seq_beam import Seq2Seq_Attn_Beam
import os


# simulate a chat: answer whatever the user types in
def interface():
    # load the trained model
    model = Seq2Seq_Attn_Beam().to(config.device)
    assert os.path.exists('./model/chatbot_attn_model.pkl'), 'please train the model first!'
    model.load_state_dict(torch.load('./model/chatbot_attn_model.pkl'))
    model.eval()

    while True:
        # tokenize the raw user input
        input_string = input('me>>:')
        if input_string == 'q':
            print('bye')
            break
        input_cuted = cut(input_string, by_word = True)
        # convert to an index sequence and wrap in tensors
        input_tensor = torch.LongTensor([config.input_ws.transform(input_cuted, max_len = config.chatbot_input_max_len)]).to(config.device)
        # clamp the length to max_len, otherwise pack_padded_sequence fails on long inputs
        input_length_tensor = torch.LongTensor([min(len(input_cuted), config.chatbot_input_max_len)]).to(config.device)
        # predict; evaluation_beam already returns a plain list of indices
        predict = model.evaluation_beam(input_tensor, input_length_tensor)
        # convert the index sequence back to text
        result = config.target_ws.inverse_transform(predict)
        print('chatbot>>:', result)


if __name__ == '__main__':
    interface()
