"""
author: lei
function: 机器人写唐诗 (a robot that writes Tang poems)
"""
import tensorflow as tf
import numpy as np
import collections
import re
def extract_data(file_path, max_len=200):
    """Read the Tang-poem corpus and return the poem bodies as training strings.

    The file is split into records on "◎".  In each record, newlines and
    ideographic spaces (U+3000) are removed and the text is split on "】";
    only records with exactly one "】" are kept, and the part after it is the
    poem body.  Bodies containing "(", "?" or "卷" are skipped (presumably
    annotations / volume headings in the corpus).  Each kept body is wrapped
    as "G" + body + "F", the start and end markers used during generation.

    Args:
        file_path: path of a UTF-8 encoded text file.
        max_len: only wrapped poems strictly shorter than this are returned.

    Returns:
        list of str, in file order.
    """
    poems = []
    with open(file_path, "r", encoding="utf8") as f:
        records = f.read().strip().split("◎")
    for record in records:
        # Strip line breaks and full-width (ideographic) spaces.
        record = record.replace("\n", "").replace("\u3000", "")
        parts = record.split("】")
        if len(parts) != 2:
            continue  # no title bracket, or more than one
        body = parts[1]
        if "(" in body or "?" in body or "卷" in body:
            continue
        poems.append("G" + body + "F")
    return [poem for poem in poems if len(poem) < max_len]
def word_to_vec(poems):
    """Build the character vocabulary for *poems*.

    Id 0 is reserved for the empty string "" (the padding entry).  All other
    characters get ids 1..N in order of descending frequency, ties broken by
    the character itself, so the mapping is deterministic.

    Args:
        poems: iterable of strings.

    Returns:
        (word_to_id, id_to_word): the mapping and its inverse.  For empty
        *poems* only the padding entry is present.
    """
    counts = collections.Counter("".join(poems))
    ordered = sorted(counts, key=lambda ch: (-counts[ch], ch))
    vocab = [""] + ordered
    word_to_id = {word: idx for idx, word in enumerate(vocab)}
    id_to_word = dict(enumerate(vocab))
    return word_to_id, id_to_word
def padding_data(data, max_len=150):
    """Right-pad or truncate each id sequence to exactly *max_len* entries.

    Padding uses 0, the id ``word_to_vec`` reserves for "".  ``dtype=int``
    is used instead of the ``np.int`` alias, which NumPy removed in 1.24.

    Args:
        data: iterable of integer sequences (lists or 1-D arrays).
        max_len: target length of every row.

    Returns:
        list of 1-D integer numpy arrays, each of length *max_len* (an empty
        input row becomes all zeros, still integer-typed).
    """
    padded_rows = []
    for line in data:
        row = np.zeros(max_len, dtype=int)
        clipped = np.asarray(line[:max_len], dtype=int)
        row[:len(clipped)] = clipped
        padded_rows.append(row)
    return padded_rows
def data_main(file_path):
    """Load the corpus and turn it into padded next-character training pairs.

    Each poem string is mapped to character ids.  The model input is the id
    sequence without its last element; the target is the same sequence
    without its first element, i.e. the next character at every position.
    Both are padded/truncated by ``padding_data``.

    Args:
        file_path: path of the corpus read by ``extract_data``.

    Returns:
        (features, labels, word_to_id, id_to_word): two 2-D numpy arrays
        (one row per poem) followed by the vocabulary maps.
    """
    poems = extract_data(file_path)
    word_to_id, id_to_word = word_to_vec(poems)
    encoded = [[word_to_id[ch] for ch in poem] for poem in poems]
    inputs = [ids[:-1] for ids in encoded]
    targets = [ids[1:] for ids in encoded]
    features = np.array(padding_data(inputs))
    labels = np.array(padding_data(targets))
    return features, labels, word_to_id, id_to_word
class PoemRobot(object):
    """Character-level stacked-LSTM model that learns Tang poems and writes new ones.

    All work happens in the constructor.  It loads the corpus with
    ``data_main`` and builds a TF1-style graph (``tf.compat.v1``) in a
    private ``tf.Graph``.  It then either trains and checkpoints the model
    (``training`` truthy) or restores the newest checkpoint and prints one
    generated poem.
    """

    def __init__(self, file_path, save_path, training, start_word="敬"):
        """Load the data, build the model, then train or generate.

        Args:
            file_path: corpus path handed to ``data_main``.
            save_path: checkpoint prefix passed to ``Saver.save``.  When
                generating, the newest checkpoint recorded in the same
                directory is restored.
            training: truthy to train; falsy to generate a poem.
            start_word: single character the generated poem starts with.
                Pass "" to let the model choose the first character.
        """
        features, labels, word_to_id, id_to_word = data_main(file_path)
        self.batch_size = 10
        self.n_steps = 150  # matches padding_data's row length; caps generation length
        self.n_layers = 3
        # Vocabulary size and poem count come from the loaded data instead of
        # hard-coded constants, so a different corpus cannot produce ids past
        # the end of the embedding table.
        self.word_size = len(word_to_id)
        self.num_neutrals = 200  # embedding width == LSTM units per layer
        self.learning_rate = 0.01
        self.training = training
        self.keep_prob = 0.5
        self.poem_nums = len(features)
        self.epoch_size = self.poem_nums // self.batch_size  # full batches per pass
        self.frequences = 10  # number of passes over the data set

        ops = self._build_graph()
        if self.training:
            self._train(ops, features, labels, save_path)
        else:
            self._generate(ops, word_to_id, id_to_word, save_path, start_word)

    def _build_graph(self):
        """Build embedding -> stacked LSTM -> dense softmax in a fresh tf.Graph.

        Batch and time dimensions are dynamic (``[None, None]``), so the same
        graph serves 10-poem training batches and single-sequence generation.
        Loss and accuracy ignore positions whose label is the padding id 0.

        Returns:
            dict of the tensors/ops used by ``_train`` and ``_generate``.
        """
        graph = tf.Graph()
        with graph.as_default():
            with tf.compat.v1.name_scope("data"):
                x = tf.compat.v1.placeholder(tf.int32, [None, None], name="x")
                # Integer next-character ids (not one-hot); 0 is padding.
                y_true = tf.compat.v1.placeholder(tf.int32, [None, None], name="y_true")
            with tf.compat.v1.name_scope("embedding"):
                embedding = tf.compat.v1.Variable(
                    tf.compat.v1.random_uniform([self.word_size, self.num_neutrals], -1.0, 1.0))
                inputs = tf.compat.v1.nn.embedding_lookup(embedding, x)
                if self.training:
                    inputs = tf.compat.v1.nn.dropout(inputs, rate=1 - self.keep_prob)
            with tf.compat.v1.variable_scope("lstm_model"):
                cells = []
                for _ in range(self.n_layers):
                    # A fresh cell per layer; reusing one instance for every
                    # layer would make the layers share weights.
                    cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(num_units=self.num_neutrals)
                    if self.training:
                        cell = tf.compat.v1.nn.rnn_cell.DropoutWrapper(
                            cell, output_keep_prob=self.keep_prob)
                    cells.append(cell)
                multi_cells = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells)
                initial_state = multi_cells.zero_state(tf.shape(x)[0], dtype=tf.float32)
                outputs, _ = tf.compat.v1.nn.dynamic_rnn(
                    multi_cells, inputs, initial_state=initial_state, dtype=tf.float32)
                # (batch, time, units) -> (batch * time, units)
                outputs = tf.reshape(outputs, [-1, self.num_neutrals])
            with tf.compat.v1.variable_scope("correct"):
                weight = tf.compat.v1.Variable(
                    tf.compat.v1.truncated_normal([self.num_neutrals, self.word_size], stddev=0.1))
                bias = tf.compat.v1.Variable(
                    tf.compat.v1.truncated_normal([self.word_size], stddev=0.1))
                logits = tf.compat.v1.nn.xw_plus_b(outputs, weight, bias)
            with tf.compat.v1.name_scope("loss"):
                labels_flat = tf.reshape(y_true, [-1])
                mask = tf.cast(tf.not_equal(labels_flat, 0), tf.float32)
                n_real = tf.maximum(tf.reduce_sum(mask), 1.0)
                xent = tf.compat.v1.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=labels_flat, logits=logits)
                loss = tf.reduce_sum(xent * mask) / n_real
            # Most likely next-character id at every (batch * time) position.
            y_predict = tf.argmax(logits, axis=1, output_type=tf.int32)
            with tf.compat.v1.name_scope("train_step"):
                train_op = tf.compat.v1.train.AdamOptimizer(self.learning_rate).minimize(loss)
            with tf.compat.v1.name_scope("acc"):
                correct = tf.cast(tf.equal(y_predict, labels_flat), tf.float32)
                accuracy = tf.reduce_sum(correct * mask) / n_real
            init_op = tf.compat.v1.global_variables_initializer()
            saver = tf.compat.v1.train.Saver()
        return {
            "graph": graph,
            "x": x,
            "y_true": y_true,
            "y_predict": y_predict,
            "train_op": train_op,
            "accuracy": accuracy,
            "init": init_op,
            "saver": saver,
        }

    def _train(self, ops, features, labels, save_path):
        """Run ``self.frequences`` passes of mini-batch Adam and save checkpoints.

        Every 50 batches, the accuracy on the current batch is printed.  From
        the 9th pass onward (``frequence > 7``), a checkpoint is written
        whenever that accuracy beats the best seen so far.  A checkpoint is
        also written at the end of every pass.
        """
        best_acc = 0
        step = 0
        with tf.compat.v1.Session(graph=ops["graph"]) as sess:
            sess.run(ops["init"])
            for frequence in range(self.frequences):
                for epoch in range(self.epoch_size):
                    step += 1
                    start = epoch * self.batch_size
                    stop = start + self.batch_size
                    # Labels are fed as integer ids (sparse loss), so no new
                    # graph nodes are created per step.
                    feed = {ops["x"]: features[start:stop], ops["y_true"]: labels[start:stop]}
                    sess.run(ops["train_op"], feed_dict=feed)
                    if epoch % 50 == 0:
                        acc_val = sess.run(ops["accuracy"], feed_dict=feed)
                        print("frequence:{}, epoch: {}, acc: {}".format(frequence, epoch, acc_val))
                        if acc_val > best_acc and frequence > 7:
                            best_acc = acc_val
                            ops["saver"].save(sess, save_path, global_step=step)
                            print("模型更好,保存成功!")
                ops["saver"].save(sess, save_path, global_step=step)
                print("模型保存成功!")

    @staticmethod
    def _checkpoint_path(save_path):
        """Return the newest checkpoint in *save_path*'s directory, else *save_path*."""
        import os
        ckpt_dir = os.path.dirname(save_path) or "."
        return tf.compat.v1.train.latest_checkpoint(ckpt_dir) or save_path

    def _generate(self, ops, word_to_id, id_to_word, save_path, start_word):
        """Restore the model and greedily generate one poem.

        Generation starts from "G" (+ *start_word*).  At each step the whole
        prefix is re-fed and the argmax prediction at the last position is
        taken.  It stops at "F", at the padding entry "", or after
        ``self.n_steps`` characters.

        Returns:
            The generated poem string, or None if *start_word* is not in the
            vocabulary.
        """
        if start_word and start_word not in word_to_id:
            print("抱歉,您的字不在词库中!")
            return None
        with tf.compat.v1.Session(graph=ops["graph"]) as sess:
            ops["saver"].restore(sess, self._checkpoint_path(save_path))
            ids = [word_to_id["G"]]

            def next_word():
                batch = np.array([ids], dtype=np.int32)  # shape (1, len(ids))
                preds = sess.run(ops["y_predict"], feed_dict={ops["x"]: batch})
                return id_to_word[int(preds[-1])]

            word = start_word if start_word else next_word()
            poem = ""
            for _ in range(self.n_steps):
                if word in ("F", ""):
                    break
                poem += word
                ids.append(word_to_id[word])
                word = next_word()
        print("生成的诗句为:", poem)
        return poem
if __name__ == "__main__":
    # Script entry point: generate a poem from the checkpoints under save_path.
    # Pass True as the third argument to train instead.
    file_path = "/home/aistudio/poem/data/tang_all.txt"
    save_path = "/home/aistudio/poem/model/"
    # Local alternative: file_path = "./data/tang_all.txt"
    poem = PoemRobot(file_path, save_path, False)