# Robot that writes Tang poems: model training

# coding=utf-8

"""
author:lei
function: robot that writes Tang poems
"""

import collections
import os
import re

import numpy as np
import tensorflow as tf

# Data extraction

def extract_data(file_path):
    """Read the poem corpus and return the usable poems as marked strings.

    The file content is split on "◎" into sections. Line breaks and
    ideographic spaces (U+3000) are removed from each section, which is then
    split on "】". Only sections that split into exactly two parts are used;
    the second part is taken as the poem body.

    A poem is skipped when its body contains a parenthesis, a question mark
    or the character "卷". Every kept poem is wrapped as
    ``"G" + body + "F"`` (start / end markers), and finally only poems whose
    marked length is below 200 characters are returned.

    Args:
        file_path: Path of the UTF-8 encoded corpus file.

    Returns:
        A list of marked poem strings, in file order.
    """
    # Both the ASCII and the fullwidth forms are checked: Chinese corpora
    # normally use the fullwidth punctuation, and matching only one form
    # would let annotated poems slip through.
    skip_markers = ("(", "\uff08", "?", "\uff1f", "卷")

    with open(file_path, "r", encoding="utf8") as f:
        sections = f.read().strip().split("◎")

    poems = []
    for section in sections:
        section = section.replace("\n", "").replace("\u3000", "")
        parts = section.split("】")
        if len(parts) != 2:
            # No single "】" separator: not a "header】body" section.
            continue
        body = parts[1]
        if any(marker in body for marker in skip_markers):
            continue
        poems.append("G" + body + "F")

    return [poem for poem in poems if len(poem) < 200]

def word_to_vec(poems):
    """Build the character vocabulary for a list of poems.

    Characters are ranked by descending frequency over all poems; ties are
    broken by the character itself. Id 0 is reserved for the empty string,
    so ranked characters start at id 1.

    Args:
        poems: Iterable of poem strings.

    Returns:
        A ``(word_to_id, id_to_word)`` tuple of dicts that are inverse to
        each other. For an empty corpus both contain only the id-0 entry.
    """
    counter = collections.Counter("".join(poems))
    ranked = sorted(counter.items(), key=lambda x: (-x[1], x[0]))

    # Building the list directly (instead of unpacking zip(*ranked)) keeps
    # this working when the corpus is empty.
    vocab = [""]
    vocab.extend(word for word, _ in ranked)

    word_to_id = {word: idx for idx, word in enumerate(vocab)}
    id_to_word = dict(enumerate(vocab))
    return word_to_id, id_to_word

def padding_data(data, max_len=150):
    """Pad or truncate every id sequence to exactly ``max_len`` entries.

    Sequences longer than ``max_len`` are cut; shorter ones are right-padded
    with zeros.

    Args:
        data: Iterable of sequences of integer ids.
        max_len: Target length of every returned row (default 150).

    Returns:
        A list with one integer ``numpy`` array of shape ``(max_len,)`` per
        input sequence.
    """
    line_list = []
    for line in data:
        # The builtin ``int`` replaces ``np.int``, an alias for it that was
        # removed in NumPy 1.24.
        row = np.zeros(max_len, dtype=int)
        clipped = line[:max_len]
        row[:len(clipped)] = clipped
        line_list.append(row)
    return line_list

def data_main(file_path):
    """Turn the poem corpus into padded training arrays.

    Each poem is converted to a list of character ids. The features are the
    ids without the last character and the labels are the ids without the
    first one, i.e. the label at every position is the next character. Both
    are padded through ``padding_data``.

    Args:
        file_path: Path of the corpus file handed to ``extract_data``.

    Returns:
        A tuple ``(features, labels, word_to_id, id_to_word)`` where
        ``features`` and ``labels`` are ``numpy`` arrays with one row per
        poem and the two dicts are the vocabulary mappings.
    """
    poems = extract_data(file_path)
    word_to_id, id_to_word = word_to_vec(poems)

    poem_list = [[word_to_id[word] for word in poem] for poem in poems]

    features = np.array(padding_data([ids[:-1] for ids in poem_list]))
    labels = np.array(padding_data([ids[1:] for ids in poem_list]))

    return features, labels, word_to_id, id_to_word

# Model construction

class PoemRobot(object):
    """Character-level multi-layer LSTM that learns and writes Tang poems.

    Constructing an instance does all the work: the corpus is loaded, the
    TF1-style graph is built, and then the model is either trained and
    checkpointed (``training is True``) or restored from the latest
    checkpoint to generate and print one poem.

    NOTE(review): each LSTM layer now owns its own cell, so the variable set
    may differ from checkpoints written by the earlier code that reused one
    cell object for all layers - retrain before generating.
    NOTE(review): written against the ``tf.compat.v1`` graph API; under
    TensorFlow 2 eager execution presumably has to be disabled by the
    caller before constructing this class - confirm.
    """

    # First character of the generated poem (the interactive prompt was
    # disabled in favour of this fixed value).
    START_WORD = "敬"

    def __init__(self, file_path, save_path, training):
        """Load the data, build the graph, then train or generate.

        Args:
            file_path: Path of the poem corpus, handed to ``data_main``.
            save_path: Checkpoint prefix used by ``Saver.save``; for
                generation the latest checkpoint in its directory is used.
            training: ``True`` to train; any other value generates a poem.
        """
        features, labels, word_to_id, id_to_word = data_main(file_path)
        self.batch_size = 10
        self.n_steps = 150
        self.n_layers = 3
        # Derived from the data instead of hard-coded (was 6702), so a
        # different corpus cannot produce out-of-range ids.
        self.word_size = len(word_to_id)
        self.num_neutrals = 200  # embedding size and LSTM units
        self.learning_rate = 0.01
        self.training = training
        self.keep_prob = 0.5
        # Derived from the data instead of hard-coded (was 39094).
        self.poem_nums = len(features)
        self.epoch_size = self.poem_nums // self.batch_size - 1
        self.frequences = 10  # number of passes over the corpus

        self._build_graph()

        if self.training is True:
            self._train(features, labels, save_path)
        else:
            self._generate(word_to_id, id_to_word, save_path)

    def _make_cell(self):
        """Create one LSTM layer, wrapped with output dropout when training."""
        cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(num_units=self.num_neutrals)
        if self.training is True:
            # The wrapper must replace the cell; calling it without using
            # the result applies no dropout at all.
            cell = tf.compat.v1.nn.rnn_cell.DropoutWrapper(
                cell, output_keep_prob=self.keep_prob)
        return cell

    def _build_graph(self):
        """Build embedding, LSTM stack, output layer, loss and metrics."""
        tf.compat.v1.reset_default_graph()

        with tf.compat.v1.name_scope("data"):
            # Batch and time dimensions are left open so the same graph
            # serves fixed-size training batches and single growing
            # sequences during generation.
            x = tf.compat.v1.placeholder(tf.int32, [None, None])
            y_true = tf.compat.v1.placeholder(tf.int32, [None, None])

        with tf.compat.v1.name_scope("embedding"):
            embedding = tf.compat.v1.Variable(
                tf.compat.v1.random_uniform(
                    [self.word_size, self.num_neutrals], -1.0, 1.0))
            inputs = tf.compat.v1.nn.embedding_lookup(embedding, x)
            if self.training is True:
                inputs = tf.compat.v1.nn.dropout(inputs, keep_prob=self.keep_prob)

        with tf.compat.v1.variable_scope("lstm_model"):
            # One distinct cell per layer; a single shared object would not
            # give the layers independent weights.
            multi_cells = tf.compat.v1.nn.rnn_cell.MultiRNNCell(
                [self._make_cell() for _ in range(self.n_layers)])
            # With no initial_state, dynamic_rnn starts from a zero state
            # sized to the actual batch.
            outputs, _ = tf.compat.v1.nn.dynamic_rnn(
                multi_cells, inputs, dtype=tf.float32)
            outputs = tf.reshape(outputs, [-1, self.num_neutrals])

        with tf.compat.v1.variable_scope("correct"):
            weight = tf.compat.v1.Variable(
                tf.compat.v1.truncated_normal(
                    [self.num_neutrals, self.word_size], stddev=0.1))
            bias = tf.compat.v1.Variable(
                tf.compat.v1.truncated_normal([self.word_size], stddev=0.1))
            logits = tf.compat.v1.nn.xw_plus_b(outputs, weight, bias)

        with tf.compat.v1.name_scope("loss"):
            labels_flat = tf.reshape(y_true, [-1])
            # One predicted id per (sequence, position), flattened.
            y_predict = tf.argmax(logits, 1)
            # Sparse cross-entropy takes the integer ids directly, so no
            # one-hot tensor has to be built for every batch.
            # NOTE(review): padded positions (id 0) are still counted in
            # loss and accuracy, as before.
            loss = tf.reduce_mean(
                tf.compat.v1.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=labels_flat, logits=logits))

        with tf.compat.v1.name_scope("train_step"):
            train_op = tf.compat.v1.train.AdamOptimizer(
                self.learning_rate).minimize(loss)

        with tf.compat.v1.name_scope("acc"):
            equal = tf.equal(tf.cast(labels_flat, tf.int64), y_predict)
            accuracy = tf.reduce_mean(tf.cast(equal, tf.float32))

        self._x = x
        self._y_true = y_true
        self._y_predict = y_predict
        self._train_op = train_op
        self._accuracy = accuracy
        self._init_op = tf.compat.v1.global_variables_initializer()
        self._saver = tf.compat.v1.train.Saver()

    def _train(self, features, labels, save_path):
        """Run the training loop, checkpointing every 50 batches."""
        with tf.compat.v1.Session() as sess:
            sess.run(self._init_op)
            new_acc = 0
            t = 0
            for frequence in range(self.frequences):
                for epoch in range(self.epoch_size):
                    t += 1
                    start = epoch * self.batch_size
                    stop = start + self.batch_size
                    feed = {
                        self._x: features[start:stop],
                        self._y_true: labels[start:stop],
                    }
                    sess.run(self._train_op, feed_dict=feed)
                    if epoch % 50 != 0:
                        continue
                    acc_val = sess.run(self._accuracy, feed_dict=feed)
                    print("frequence:{}, epoch: {}, acc: {}".format(frequence, epoch, acc_val))
                    # A checkpoint is written either way; only the message
                    # and the tracked best accuracy differ.
                    improved = acc_val > new_acc and frequence > 7
                    if improved:
                        new_acc = acc_val
                    self._saver.save(sess, save_path, global_step=t)
                    if improved:
                        print("模型更好,保存成功!")
                    else:
                        print("模型保存成功!")

    def _generate(self, word_to_id, id_to_word, save_path):
        """Restore the latest checkpoint, then generate and print one poem."""
        # data = input("请输入您的第一个字:")
        data = self.START_WORD
        if data not in word_to_id:
            print("抱歉,您的字不在词库中!")
            raise SystemExit(0)

        with tf.compat.v1.Session() as sess:
            # Checkpoints are saved as "<save_path>-<step>", so resolve the
            # newest one from the checkpoint directory; fall back to
            # save_path itself in case it already names a checkpoint.
            checkpoint_dir = os.path.dirname(save_path) or "."
            checkpoint = tf.compat.v1.train.latest_checkpoint(checkpoint_dir)
            self._saver.restore(sess, checkpoint or save_path)

            # Every poem was trained with a leading "G" start marker.
            ids = [word_to_id["G"]]
            poem = ""
            word = data
            # Bounded by the training sequence length so a model that never
            # emits the "F" end marker cannot loop forever.
            for _ in range(self.n_steps):
                if word == "F":
                    break
                poem += word
                ids.append(word_to_id[word])
                # The whole prefix is fed each step, so the LSTM sees the
                # full context without carrying state between runs.
                predicted = sess.run(self._y_predict, feed_dict={self._x: [ids]})
                word = id_to_word[int(predicted[-1])]

        self.poem = poem
        print("生成的诗句为:", poem)

if __name__ == "__main__":
    # Script entry point: corpus and checkpoint locations.
    file_path = "/home/aistudio/poem/data/tang_all.txt"
    save_path = "/home/aistudio/poem/model/"
    # Alternative corpus location for a local checkout: "./data/tang_all.txt"
    # The third argument is the training flag; False restores the saved
    # model and generates a poem instead of training.
    poem = PoemRobot(file_path, save_path, False)

# (Blog page footer, not part of the program: "You may also be interested in: deep learning")