Input format: batch_size*784 is reshaped to batch_size*28*28, turning the 784 pixels into a sequence of 28 rows, where each row holds the 28 grayscale values of one row of pixels.
The idea is to let the network scan a handwritten digit row by row, summarize the features of each row, chain them together through the time sequence, and draw a conclusion at the end.
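The reshaping described above can be sketched in isolation with numpy (the batch here is random made-up data, not real MNIST images):

```python
import numpy as np

# A hypothetical mini-batch of 4 flattened MNIST images (random stand-in values).
batch = np.random.rand(4, 784).astype(np.float32)

# Reshape each 784-pixel vector into a 28-step sequence of 28-pixel rows,
# i.e. the batch_size*28*28 layout the RNN consumes.
seq = batch.reshape(-1, 28, 28)

print(seq.shape)  # (4, 28, 28)
# Row t of image 0 is exactly pixels [28*t, 28*(t+1)) of the flat vector:
print(np.array_equal(seq[0, 3], batch[0, 28 * 3:28 * 4]))  # True
```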
Network definition: define a separate cell-factory function so it can be called from MultiRNNCell, building a multi-layer LSTM network.
def get_a_cell(i):
    lstm_cell = rnn.BasicLSTMCell(num_units=HIDDEN_CELL, forget_bias=1.0, state_is_tuple=True, name='layer_%s' % i)
    print(type(lstm_cell))
    dropout_wrapped = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob)
    return dropout_wrapped

multi_lstm = rnn.MultiRNNCell(cells=[get_a_cell(i) for i in range(LSTM_LAYER)],
                              state_is_tuple=True)  # tf.nn.rnn_cell.MultiRNNCell
A few other details and pitfalls: an RNN can be run in several ways. The simplest is dynamic_rnn, which hands you the results directly.
outputs, state = tf.nn.dynamic_rnn(multi_lstm, inputs = tf_x_reshaped, initial_state = init_state, time_major = False)
final_out = outputs[:,-1,:]
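The slicing on the last line can be illustrated with plain numpy arrays (shapes mirror dynamic_rnn with time_major=False; the values are random stand-ins):

```python
import numpy as np

# outputs from dynamic_rnn has shape (batch_size, seq_len, hidden).
batch_size, seq_len, hidden = 4, 28, 256
outputs = np.random.rand(batch_size, seq_len, hidden).astype(np.float32)

# outputs[:, -1, :] keeps only the hidden vector after the last row was read:
# one feature vector per image, which then feeds the classification layer.
final_out = outputs[:, -1, :]
print(final_out.shape)  # (4, 256)
print(np.array_equal(final_out[0], outputs[0, seq_len - 1]))  # True
```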
You can also write a loop and run the cell manually for STEP_SIZE steps to get the final result (two forms shown below; either way it boils down to invoking __call__):
outputs = list()
state = init_state
with tf.variable_scope('RNN'):
    for timestep in range(STEP_SIZE):
        # (cell_output, state) = multi_lstm(tf_x_reshaped[:, timestep, :], state)
        (cell_output, state) = multi_lstm.call(tf_x_reshaped[:, timestep, :], state)
        outputs.append(cell_output)
        # print('cell_output:', cell_output)
h_state = outputs[-1]
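The manual unrolling above can be mimicked with a toy numpy cell: call the cell once per timestep and thread the state through (the vanilla-RNN cell, weights, and sizes below are made-up stand-ins, not the LSTM from this example):

```python
import numpy as np

rng = np.random.default_rng(0)
input_size, hidden = 28, 8
Wx = rng.normal(size=(input_size, hidden)) * 0.1
Wh = rng.normal(size=(hidden, hidden)) * 0.1

def cell(x_t, h):
    # One __call__-style step of a toy vanilla RNN cell.
    return np.tanh(x_t @ Wx + h @ Wh)

x = rng.normal(size=(4, 28, input_size))  # (batch, steps, features)
h = np.zeros((4, hidden))                 # zero initial state
outputs = []
for t in range(x.shape[1]):               # run the cell STEP_SIZE times
    h = cell(x[:, t, :], h)
    outputs.append(h)

h_state = outputs[-1]                     # final-step output, as above
print(h_state.shape)  # (4, 8)
```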
batch_size matters for an RNN, because the LSTM has a zero state that must be initialized, and that state is tied to batch_size. So it is best to feed batch_size through a placeholder rather than hard-code it as a constant.
Here, if the test data are fed with a batch_size larger than the training batch_size, it throws an error! Of course, if I'm too lazy to change the code, I can also use a small loop that evaluates several small batches and averages the results at the end.
init_state = multi_lstm.zero_state(batch_size = BATCH_SIZE, dtype = tf.float32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.summary.FileWriter('graph', graph=sess.graph)
    for i in range(2000):
        x, y = MNIST.train.next_batch(BATCH_SIZE)
        _, loss_, outputs_, state_, right_predictions_num_ = \
            sess.run([train_op, cross_entropy, outputs, state, right_predictions_num],
                     {tf_x: x, tf_y: y, keep_prob: 1.0})
        print('loss:', loss_)
        # print('right_predictions_num_:', right_predictions_num_)
        if i % 200 == 0:
            # tensorflow.python.framework.errors_impl.InvalidArgumentError: ConcatOp: Dimensions of inputs should match: shape[0] = [1000, 28] vs. shape[1] = [100, 256]
            # test_x, test_y = MNIST.test.next_batch(BATCH_SIZE * 10)
            total_accuracy = 0.
            total_test_batch = 10
            for j in range(total_test_batch):
                test_x, test_y = MNIST.test.next_batch(BATCH_SIZE)
                accuracy_ = sess.run([accuracy], {tf_x: test_x, tf_y: test_y, keep_prob: 1.0})
                total_accuracy += accuracy_[0]
            total_accuracy = total_accuracy / total_test_batch
            print('total_accuracy:', total_accuracy)
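The averaging trick in the evaluation loop, shown in isolation: when one big test batch would break the graph's fixed batch_size, evaluate several small batches and average their accuracies (the accuracy values below are made up for illustration; note an unweighted mean is only exact when every batch has the same size):

```python
# Hypothetical per-batch accuracies from several small test batches.
batch_accuracies = [0.91, 0.93, 0.92, 0.94, 0.90]

total_accuracy = sum(batch_accuracies) / len(batch_accuracies)
print(round(total_accuracy, 2))  # 0.92
```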
Full implementation of this example:
https://github.com/huqinwei/tensorflow_demo/blob/master/lstm_mnist/multi_lstm.py
How the state of a multi-layer LSTM is structured:
https://blog.csdn.net/huqinweI987/article/details/83148239