前言
上篇写过一个机器学习写唐诗的实验,这次我们搞个稍微复杂些的,实现一个聊天机器人,也是基于腾讯云实验室的一篇教程,有些部分做了改动,大部分时间都用在了环境的适配上面。开始本地是在Mac环境,单独依靠CPU训练,比较慢。后来找了个配置比较好的机器, 6核心12线程,效果好一些。总结来说,机器学习相关有两个重点,一个是基础的训练资源,包括对原始数据的清洗处理和规范化,训练中其实模型是没有很大区别的。其次,是好的机器配置,资源有限,没有上GPU。这次实验,本地训练大概半天到4000步的时候,还只是个复读机,换了高配机器1天左右就可以到30万左右,两天到70万,基本达到损失率稳定(30万就可以)。
以下是本地机器的配置,奈何效果不行
MacBook Pro (13-inch, 2017, Four Thunderbolt 3 Ports)
10.13.6 (17G65)16 GB 2133 MHz LPDDR3
3.1 GHz Intel Core i5
注意事项
:
强烈建议使用virtualenv配置python,简单而且不会对本地运行环境造成影响。
同时需要安装好TensorFlow环境
过程步骤
实验内容
首先进行数据的清洗,处理。提取ask和answer数据,并得到字典,以及做向量化处理。训练数据可以使用本次实验链接里的,也可以使用网上的小黄鸡等等语料。注意这里的字典之前查的资料是满足3000左右的常用汉字就可以,是在语料中找到常用字。
-
模型学习部分。
这里引用了seq2seq的部分,单独有一些修改。之前下载实验中提供的训练了30万次左右的模型直接进行对话,但是本地一直提示错误。最终选择了自己训练,保存了完整的checkpoint文件,可以启动程序。如图最终训练在71万次左右,其实30万左右损失率基本就已经不变了,如果能提供更优化的语料应该效果会更好。后续有链接提供所有资料,可以直接下载。
-
模拟对话,这部分是最终的成果,启动本地依赖,加载训练模型之后就可以对话了,效果看图,可以看到有些句子还是可以对上的,一问一答,有些幼稚。
代码部分
- 数据整理和向量化 generate.py
# -*- coding:utf-8 -*-
from io import open
import random
import tensorflow as tf
# version tf 1.12 2018-12-08 22:22:08
PAD = "PAD"
GO = "GO"
EOS = "EOS"
UNK = "UNK"
START_VOCAB = [PAD, GO, EOS, UNK]
PAD_ID = 0 # 填充
GO_ID = 1 # 开始标志
EOS_ID = 2 # 结束标志
UNK_ID = 3 # 未知字符
_buckets = [(10, 15), (20, 25), (40, 50), (80, 100)]
units_num = 256
num_layers = 3
max_gradient_norm = 5.0
batch_size = 50
learning_rate = 0.5
learning_rate_decay_factor = 0.97
train_encode_file = "data/train_encode"
train_decode_file = "data/train_decode"
test_encode_file = "data/test_encode"
test_decode_file = "data/test_decode"
vocab_encode_file = "data/vocab_encode"
vocab_decode_file = "data/vocab_decode"
train_encode_vec_file = "data/train_encode_vec"
train_decode_vec_file = "data/train_decode_vec"
test_encode_vec_file = "data/test_encode_vec"
test_decode_vec_file = "data/test_decode_vec"
def is_chinese(sentence):
flag = True
if len(sentence) < 2:
flag = False
return flag
for uchar in sentence:
if (uchar == ',' or uchar == '。' or
uchar == '~' or uchar == '?' or
uchar == '!'):
flag = True
elif '一' <= uchar <= '鿿':
flag = True
else:
flag = False
break
return flag
def get_chatbot():
f = open("data/chat.conv", "r", encoding="utf-8")
train_encode = open(train_encode_file, "w", encoding="utf-8")
train_decode = open(train_decode_file, "w", encoding="utf-8")
test_encode = open(test_encode_file, "w", encoding="utf-8")
test_decode = open(test_decode_file, "w", encoding="utf-8")
vocab_encode = open(vocab_encode_file, "w", encoding="utf-8")
vocab_decode = open(vocab_decode_file, "w", encoding="utf-8")
encode = list()
decode = list()
chat = list()
print("start load source data...")
step = 0
for line in f.readlines():
line = line.strip('\n').strip()
if not line:
continue
if line[0] == "E":
if step % 1000 == 0:
print("step:%d" % step)
step += 1
if (len(chat) == 2 and is_chinese(chat[0]) and is_chinese(chat[1]) and
not chat[0] in encode and not chat[1] in decode):
encode.append(chat[0])
decode.append(chat[1])
chat = list()
elif line[0] == "M":
L = line.split(' ')
if len(L) > 1:
chat.append(L[1])
encode_size = len(encode)
if encode_size != len(decode):
raise ValueError("encode size not equal to decode size")
test_index = random.sample([i for i in range(encode_size)], int(encode_size * 0.2))
print("divide source into two...")
step = 0
for i in range(encode_size):
if step % 1000 == 0:
print("%d" % step)
step += 1
if i in test_index:
test_encode.write(encode[i] + "\n")
test_decode.write(decode[i] + "\n")
else:
train_encode.write(encode[i] + "\n")
train_decode.write(decode[i] + "\n")
vocab_encode_set = set(''.join(encode))
vocab_decode_set = set(''.join(decode))
print("get vocab_encode...")
step = 0
for word in vocab_encode_set:
if step % 1000 == 0:
print("%d" % step)
step += 1
vocab_encode.write(word + "\n")
print("get vocab_decode...")
step = 0
for word in vocab_decode_set:
print("%d" % step)
step += 1
vocab_decode.write(word + "\n")
def gen_chatbot_vectors(input_file, vocab_file, output_file):
vocab_f = open(vocab_file, "r", encoding="utf-8")
output_f = open(output_file, "w")
input_f = open(input_file, "r", encoding="utf-8")
words = list()
for word in vocab_f.readlines():
word = word.strip('\n').strip()
words.append(word)
word_to_id = {word: i for i, word in enumerate(words)}
to_id = lambda word: word_to_id.get(word, UNK_ID)
print("get %s vectors" % input_file)
step = 0
for line in input_f.readlines():
if step % 1000 == 0:
print("step:%d" % step)
step += 1
line = line.strip('\n').strip()
vec = map(to_id, line)
output_f.write(' '.join([str(n) for n in vec]) + "\n")
def get_vectors():
gen_chatbot_vectors(train_encode_file, vocab_encode_file, train_encode_vec_file)
gen_chatbot_vectors(train_decode_file, vocab_decode_file, train_decode_vec_file)
gen_chatbot_vectors(test_encode_file, vocab_encode_file, test_encode_vec_file)
gen_chatbot_vectors(test_decode_file, vocab_decode_file, test_decode_vec_file)
def get_vocabs(vocab_file):
words = list()
with open(vocab_file, "r", encoding="utf-8") as vocab_f:
for word in vocab_f:
words.append(word.strip('\n').strip())
id_to_word = {i: word for i, word in enumerate(words)}
word_to_id = {v: k for k, v in id_to_word.items()}
vocab_size = len(id_to_word)
return id_to_word, word_to_id, vocab_size
def read_data(source_path, target_path, max_size=None):
data_set = [[] for _ in _buckets]
with tf.gfile.GFile(source_path, mode="r") as source_file:
with tf.gfile.GFile(target_path, mode="r") as target_file:
source, target = source_file.readline(), target_file.readline()
counter = 0
while source and target and (not max_size or counter < max_size):
counter += 1
source_ids = [int(x) for x in source.split()]
target_ids = [int(x) for x in target.split()]
target_ids.append(EOS_ID)
for bucket_id, (source_size, target_size) in enumerate(_buckets):
if len(source_ids) < source_size and len(target_ids) < target_size:
data_set[bucket_id].append([source_ids, target_ids])
break
source, target = source_file.readline(), target_file.readline()
return data_set
# run
#获取 ask、answer 数据并生成字典
# get_chatbot()
#训练数据转化为数字表示
# get_vectors()
- 学习模型
限制太长无法发布,只能在最后的链接获取了
seq2seq.py
seq2seq_model.py
- 训练模块
可以改小配置中的step部分,简单验证下效果。这里有些改动,加了间隔一定步骤之后,保存checkpoint到本地的功能,防止中间如果有异常,比如断电或者不小心关闭程序或者其他原因造成程序崩溃,导致前功尽弃。
train_chat.py
# -*- coding:utf-8 -*-
import generate as generate_chat
import seq2seq_model as seq2seq_model
import tensorflow as tf
import numpy as np
import logging
import logging.handlers
if __name__ == '__main__':
_, _, source_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_encode_file)
_, _, target_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_decode_file)
train_set = generate_chat.read_data(generate_chat.train_encode_vec_file, generate_chat.train_decode_vec_file)
test_set = generate_chat.read_data(generate_chat.test_encode_vec_file, generate_chat.test_decode_vec_file)
train_bucket_sizes = [len(train_set[i]) for i in range(len(generate_chat._buckets))]
train_total_size = float(sum(train_bucket_sizes))
train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size for i in range(len(train_bucket_sizes))]
cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6,inter_op_parallelism_threads=6,device_count={'CPU':6})
with tf.Session(config=cpu_config) as sess:
model = seq2seq_model.Seq2SeqModel(source_vocab_size,
target_vocab_size,
generate_chat._buckets,
generate_chat.units_num,
generate_chat.num_layers,
generate_chat.max_gradient_norm,
generate_chat.batch_size,
generate_chat.learning_rate,
generate_chat.learning_rate_decay_factor,
use_lstm=True)
ckpt = tf.train.get_checkpoint_state('./mytrain')
if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
model.saver.restore(sess, ckpt.model_checkpoint_path)
else:
print("Created model with fresh parameters.")
sess.run(tf.global_variables_initializer())
loss = 0.0
step = 0
previous_losses = []
while True:
random_number_01 = np.random.random_sample()
bucket_id = min([i for i in range(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01])
encoder_inputs, decoder_inputs, target_weights = model.get_batch(train_set, bucket_id)
_, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, False)
print("step:%d,loss:%f" % (step, step_loss))
loss += step_loss / 2000
step += 1
if step % 1000 == 0:
print("step:%d,per_loss:%f" % (step, loss))
if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
sess.run(model.learning_rate_decay_op)
previous_losses.append(loss)
model.saver.save(sess, "mytrain/chatbot.ckpt", global_step=model.global_step)
loss = 0.0
if step % 5000 == 0:
for bucket_id in range(len(generate_chat._buckets)):
if len(test_set[bucket_id]) == 0:
continue
encoder_inputs, decoder_inputs, target_weights = model.get_batch(test_set, bucket_id)
_, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id,
True)
print("bucket_id:%d,eval_loss:%f" % (bucket_id, eval_loss))
- 对话模块
chat.py
# -*- coding:utf-8 -*-
import generate as generate_chat
import seq2seq_model as seq2seq_model
import tensorflow as tf
import numpy as np
import sys
if __name__ == '__main__':
source_id_to_word, source_word_to_id, source_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_encode_file)
target_id_to_word, target_word_to_id, target_vocab_size = generate_chat.get_vocabs(generate_chat.vocab_decode_file)
to_id = lambda word: source_word_to_id.get(word, generate_chat.UNK_ID)
cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6,inter_op_parallelism_threads=6,device_count={'CPU':6})
with tf.Session(config=cpu_config) as sess:
model = seq2seq_model.Seq2SeqModel(source_vocab_size,
target_vocab_size,
generate_chat._buckets,
generate_chat.units_num,
generate_chat.num_layers,
generate_chat.max_gradient_norm,
1,
generate_chat.learning_rate,
generate_chat.learning_rate_decay_factor,
forward_only=True,
use_lstm=True)
#model.saver.restore(sess, "model/chatbot.ckpt-317000")
model.saver.restore(sess, "mytrain/chatbot.ckpt-717000")
while True:
sys.stdout.write("ask > ")
sys.stdout.flush()
sentence = sys.stdin.readline().strip('\n')
flag = generate_chat.is_chinese(sentence)
if not sentence or not flag:
print("请输入纯中文")
continue
sentence_vec = list(map(to_id, sentence))
bucket_id = len(generate_chat._buckets) - 1
if len(sentence_vec) > generate_chat._buckets[bucket_id][0]:
print("sentence too long max:%d" % generate_chat._buckets[bucket_id][0])
exit(0)
for i, bucket in enumerate(generate_chat._buckets):
if bucket[0] >= len(sentence_vec):
bucket_id = i
break
encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(sentence_vec, [])]},
bucket_id)
_, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
if generate_chat.EOS_ID in outputs:
outputs = outputs[:outputs.index(generate_chat.EOS_ID)]
answer = "".join([tf.compat.as_str(target_id_to_word[output]) for output in outputs])
print("answer > " + answer)
注意
这里在train_chat.py 和 chat.py中,tf.session
有个配置改动,限制了使用的CPU数,在Ubuntu下如果没有限制,会造成TF占用所有的CPU资源,导致系统卡死,具体数值根据CPU核心数设置。
代码如下:
cpu_config = tf.ConfigProto(intra_op_parallelism_threads=6,inter_op_parallelism_threads=6,device_count={'CPU':6})
with tf.Session(config=cpu_config) as sess:
结语
感谢阅读,最后放上实验的实际地址和我自己训练的所有资源,本地实验在mac tf 1.12.0 和 python3.6.7,以及Ubuntu tf.1.12.0 和 python3.5环境下都正常,再次建议在virtualenv环境下。
实验链接(时间过久可能失效):https://cloud.tencent.com/developer/labs/lab/10406
本地实验资源:https://iss.igosh.com/share/201903/tencent-me.tar.gz