学习tensorflow的生成,在网上查到生成诗的例子,改了一下,将所有的代码全部放在一个文件中,进行运行。
#!-*- coding: utf-8 -*-
import math
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
# 禁用词,包含如下字符的唐诗将被忽略
disallowed_words = ['(', ')', '(', ')', '__', '《', '》', '【', '】', '[', ']']
max_len = 64
# 最小词频
min_word_frequency = 6
# 共训练多少个epoch
epochs = 20
# 训练的batch size
batch_size = 128
# 数据集路径
dataset_path = './poetry.txt'
random = True
# 加载数据集
lines = []
with open(dataset_path, 'r', encoding='utf-8') as f:
for line in f.readlines():
# 将冒号统一成相同格式
lines.append(line.replace(':', ':'))
# 数据集列表
poetries = []
# 逐行处理读取到的数据
for line in lines:
# 有且只能有一个冒号用来分割标题
if line.count(':') != 1:
continue
# 后半部分不能包含禁止词
__, last_part = line.split(':')
ignore_flag = False
for dis_word in disallowed_words:
if dis_word in last_part:
ignore_flag = True
break
if ignore_flag:
continue
# 长度不能超过最大长度
if len(last_part) > max_len - 2:
continue
# 为了使用tensorflow的text的分词器,这里将分割使用" "
poetries.append(last_part.replace('\n', '').replace("", ' ').strip())
# 分词器
tokenize = Tokenizer()
tokenize.fit_on_texts(poetries)
# 字汇
words = [word for word, count in tokenize.word_counts.items() if count > min_word_frequency]
words = ['[PAD]', '[UNK]', '[CLS]', '[SEP]'] + words
vocab_size = len(words)
# 创建字典
word_dict = dict(zip(words, range(vocab_size)))
index_word = dict((value, key) for key, value in word_dict.items())
# 获取总的训练数据
total = len(poetries)
best_model_path = 'my_bast_model.h5'
def __epoch_iter__():
# 是否随机混洗
if random:
np.random.shuffle(poetries)
# 迭代一个epoch,每次yield一个batch
for start in range(0, total, batch_size):
end = min(start + batch_size, total)
batch_data = []
# 逐一对古诗进行编码
for single_data in poetries[start:end]:
single_seq = [word_dict.get(item, word_dict['[UNK]']) for item in single_data if item != ' ']
single_seq = [word_dict['[CLS]']] + single_seq + [
word_dict['[SEP]']]
batch_data.append(single_seq)
# 填充为相同长度
batch_data = pad_sequences(batch_data, value=word_dict['[PAD]'])
# yield x,y
yield batch_data[:, :-1], tf.one_hot(batch_data[:, 1:], vocab_size)
del batch_data
def get_train():
while True:
yield from __epoch_iter__()
if os.path.exists(best_model_path):
# 加载模型
print(f"===== 加载模型 {best_model_path} ======")
model = load_model(best_model_path)
else:
# 构建模型
model = tf.keras.Sequential([
# 不定长度的输入
tf.keras.layers.Input((None,)),
# 词嵌入层
tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=128),
# 第一个LSTM层,返回序列作为下一层的输入
tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
# 第二个LSTM层,返回序列作为下一层的输入
tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
# 对每一个时间点的输出都做softmax,预测下一个词的概率
tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(vocab_size, activation='softmax')),
])
# 查看模型结构
model.summary()
# 配置优化器和损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.categorical_crossentropy)
class Evaluate(tf.keras.callbacks.Callback):
"""
在每个epoch训练完成后,保留最优权重,并随机生成settings.SHOW_NUM首古诗展示
"""
def __init__(self):
super().__init__()
# 给loss赋一个较大的初始值
self.lowest = 1e10
def on_epoch_end(self, epoch, logs=None):
# 在每个epoch训练完成后调用
# 如果当前loss更低,就保存当前模型参数
if logs['loss'] <= self.lowest:
self.lowest = logs['loss']
model.save(best_model_path)
# 随机生成几首古体诗测试,查看训练效果
token_ids = [word_dict['[CLS]']]
while len(token_ids) < max_len:
output = model(np.reshape(token_ids, (1, -1)))
probas = output.numpy()[0, -1, 3:]
del output
if np.argmax(probas) == 0:
break
probas = probas[1:]
# 按照出现概率,对所有token倒序排列
p_args = probas.argsort()[::-1][:100]
# 排列后的概率顺序
p = probas[p_args]
# 先对概率归一
p = p / sum(p)
# 再按照预测出的概率,随机选择一个词作为预测结果
target_index = np.random.choice(len(p), p=p)
target = p_args[target_index] + 4
# 保存
token_ids.append(target)
out_text = [index_word[item] for item in token_ids]
print(''.join(out_text[1:]).replace("。", "。\n"))
# 开始训练
model.fit(get_train(), epochs=epochs,
steps_per_epoch=int(math.floor(total / batch_size)),
callbacks=[Evaluate()])
运行结果如下:
===== 加载模型 my_bast_model.h5 ======
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) (None, None, 128) 456064
_________________________________________________________________
lstm (LSTM) (None, None, 128) 131584
_________________________________________________________________
lstm_1 (LSTM) (None, None, 128) 131584
_________________________________________________________________
time_distributed (TimeDistri (None, None, 3563) 459627
=================================================================
Total params: 1,178,859
Trainable params: 1,178,859
Non-trainable params: 0
_________________________________________________________________
Epoch 1/20
191/191 [==============================] - 63s 287ms/step - loss: 3.4426
朝朝不识老,与有君时便。
此年千地行,长与故阳客。
清上碧庭色,清溪已故夕。
无此不同心,时间泪有游。
Epoch 2/20
191/191 [==============================] - 55s 289ms/step - loss: 3.4255
山门与古远,尽夜生幽城。
世坐知书梦,相知道自闲。
人朝孤户下,红井草还间。
欲得君亲至,同归独不迷。
Epoch 3/20
191/191 [==============================] - 55s 287ms/step - loss: 3.4091
金公八子至,八国未难逢。
大火从相息,千门不见兵。
玉山流岸雪,疏响月鸣回。
却恨不伤到,无来一后游。
Epoch 4/20
191/191 [==============================] - 55s 286ms/step - loss: 3.4174
幽来大灵名,所不无同归。
风前月下水,水发自飞云。
路起心无住,归林一自斜。
烟尘千万落枝长,乱鸟初含玉辇烟。
Epoch 5/20
191/191 [==============================] - 55s 288ms/step - loss: 3.3755
何闻有所及,不是日光迟。
吴首向关路,又然清易伤。
雪连空北尽,孤月待天稀。
遥似西台去,多年谢病期。
Epoch 6/20
191/191 [==============================] - 55s 287ms/step - loss: 3.3861
霜声入风水,月日欲长秋。
天石烟云落,青窗月满扉。
不勤经我远,高节向云山。
远里浮山意,还今应奈亲。
Epoch 7/20
191/191 [==============================] - 55s 286ms/step - loss: 3.3795
妾处青云寺,空行满梦情。
闲风秋雨起,江雪海云寒。
花色前峰日,春前见马流。
无乡如思意,一日话吾亲。
Epoch 8/20
191/191 [==============================] - 59s 310ms/step - loss: 3.3618
雨出见南园,暮愁江雨稀。
一高人与客,千古在山时。
不觉随湖发,谁逢月下云。
长家无共过,吟去寄乡名。
Epoch 9/20
191/191 [==============================] - 56s 293ms/step - loss: 3.3708
江阳三下楚,风雨动征年。
草雨山阴雪,云霜出岸灯。
归思无客贵,从见泪离身。
Epoch 10/20
191/191 [==============================] - 57s 297ms/step - loss: 3.3599
荷霏香翠遍,风彩郁成尘。
御玉天霞入,残轩积雨林。
沙浓声落暗,海水滴生流。
不要休何处,空明梦复多。
古诗文件是直接使用网上的,也可以将github中的古诗集项目拿来用,这里不做演示。
项目源码https://gitee.com/MIEAPP/deep-learning/tree/master/example10