LSTM简单的例子

LSTM生成评论的例子

使用前10个字推出后面的1个字

import numpy
from keras.models import Sequential
from keras.layers import Dense, Dropout, LSTM
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils

# 读取txt文件
filename = 'comments.txt'
with open(filename, 'r', encoding='utf-8') as f:
    raw_text = f.read().lower()

# 创建文字和对应数字字典
chars = sorted(list(set(raw_text)))
char_to_int = dict((c, i) for i, c in enumerate(chars))
int_to_char = dict((i, c) for i, c in enumerate(chars))

# 对加载数据做总结
n_chars = len(raw_text)
n_vocab = len(chars)
print("总的文字数:", n_chars)
print("总的文字类别:", n_vocab)

# 生成数据集,转化为输入向量和输出向量
seq_length = 10
dataX = []
dataY = []
for i in range(0, n_chars - seq_length, 1):
    seq_in = raw_text[i: i + seq_length]	# 输入前10个字
    seq_out = raw_text[i + seq_length]		# 输出后 1个字
    dataX.append([char_to_int[char] for char in seq_in]) 	# 将字转化成对应的序号
    dataY.append(char_to_int[seq_out])
n_patterns = len(dataX) 					# 数据集的大小
print("Total Patterns: ", n_patterns)

# 将X重新转化为[samples, time_steps, features]形状
X = numpy.reshape(dataX, (n_patterns, seq_length, 1))
X = X / n_vocab
y = np_utils.to_categorical(dataY)

# 定义LSTM
model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))   # 输入维度(10, 1)
model.add(Dropout(0.2))
model.add(Dense(y.shape[1], activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')

filepath = "./LSTM/weights-improvement.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
callbacks_list = [checkpoint]

# 模型训练
epochs = 1000
model.fit(X, y, epochs=epochs, batch_size=128, callbacks=callbacks_list)

## 模型预测 ====
input = '杭州西湖天下闻名,西'
pattern = [char_to_int[value] for value in input]
print("输入:")
print(''.join([int_to_char[value] for value in pattern]))
print("输出:")
for i in range(1000):
    x = numpy.reshape(pattern, (1, len(pattern), 1))
    x = x / float(n_vocab)
    prediction = model.predict(x, verbose=0)
    index = numpy.argmax(prediction)
    result = int_to_char[index]
    print(result, end='')
    seq_in = [int_to_char[value] for value in pattern]
    pattern.append(index)
    pattern = pattern[1: len(pattern)] 			   # 这里的pattern永远都是10个字
print("\n生成完毕。")

转载自

  • https://www.cnblogs.com/jclian91/p/9863848.html, 非常容易理解的一个例子,生成西湖评论
    • 数据也可以使用这里面的爬虫实现

你可能感兴趣的:(学习笔记)