继上一篇用简单的卷积神经网络(CNN)做 MNIST 分类之后,本篇文章用循环神经网络(RNN)替换 CNN,实现了一个 MNIST 分类实例。实例包含两个文件:
train.py:数据加载和训练代码。
# coding=utf-8
import tensorflow as tf
import os
import model
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST dataset from 'dataset/' with one-hot encoded labels.
mnist = input_data.read_data_sets('dataset/', one_hot=True)

# Command-line flags; short alias keeps the definitions compact.
_flags = tf.app.flags

# Input geometry and number of target classes.
_flags.DEFINE_integer('sequence_step', 28, 'step of input sequence')
_flags.DEFINE_integer('vector_size', 28, 'length of input vector')
_flags.DEFINE_integer('num_classes', 10, 'num of class')

# Optimisation settings.
_flags.DEFINE_float('lr', 0.001, 'learning rate')
_flags.DEFINE_integer('batch_size', 32, 'batch size')
_flags.DEFINE_integer('epochs', 50, 'num of epoch')

# Checkpoint handling.
_flags.DEFINE_string('checkpoints', './checkpoints/model.ckpt', 'path of checkpoints')
_flags.DEFINE_boolean('continue_training', False, 'continue')

FLAGS = _flags.FLAGS
def main(_):
    """Build the RNN classifier graph, train it on MNIST and checkpoint it.

    Hyper-parameters come from the module-level ``FLAGS``; data comes from the
    module-level ``mnist`` dataset. After every epoch the accuracy on a random
    test batch is printed and the model is saved to ``FLAGS.checkpoints``.

    Args:
        _: Unused positional argument supplied by ``tf.app.run``.
    """
    # Flattened images in, one-hot labels out.
    images = tf.placeholder(
        dtype=tf.float32,
        shape=[None, FLAGS.sequence_step * FLAGS.vector_size])
    labels = tf.placeholder(dtype=tf.int32, shape=[None, 10])

    # network
    logits = model.build_rnn(
        images, FLAGS.sequence_step, FLAGS.vector_size, FLAGS.batch_size)
    # loss
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    # optimizer: pass the --lr flag so it actually takes effect
    train_op = tf.train.AdamOptimizer(learning_rate=FLAGS.lr).minimize(cross_entropy)
    # evaluation
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver(max_to_keep=1000)

    # control GPU resource utilization
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    # Save where --continue_training restores from, so the two stay in sync.
    checkpoint_dir = os.path.dirname(FLAGS.checkpoints)
    if checkpoint_dir and not os.path.isdir(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # NOTE(review): the TF tutorial MNIST reader appears to return float images
    # already scaled to [0, 1]; confirm whether the extra division by 255 is
    # intended. It is kept (for train and test alike) so results and existing
    # checkpoints remain comparable.
    test_images = mnist.test.images / 255.0
    test_labels = mnist.test.labels
    steps_per_epoch = mnist.train.num_examples // FLAGS.batch_size

    # The context manager installs the session as default (needed by
    # ``accuracy.eval``) and closes it when training ends.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        # Restore weights file
        if FLAGS.continue_training:
            saver.restore(sess, FLAGS.checkpoints)

        # begin train
        for _epoch in range(FLAGS.epochs):
            for _step in range(steps_per_epoch):
                train_image, train_label = mnist.train.next_batch(FLAGS.batch_size)
                train_image = train_image / 255.0
                _, loss, acc = sess.run(
                    [train_op, cross_entropy, accuracy],
                    feed_dict={images: train_image, labels: train_label})
                print('loss : %f accuracy : %f' % (loss, acc))

            # Evaluate on a random test subset of batch_size images, matching
            # the batch size handed to build_rnn.
            test_index = np.random.permutation(len(test_images))[:FLAGS.batch_size]
            print('精确率:', accuracy.eval(
                {images: test_images[test_index], labels: test_labels[test_index]}))

            saver.save(sess, FLAGS.checkpoints)


if __name__ == '__main__':
    tf.app.run()
训练部分基本和上一篇CNN分类相同。
model.py:搭建了一个简单的循环神经网络。RNN 的输入形式和 CNN 略有不同:针对 28×28 的图片,CNN 采用 [batch, height, width, channel] 形式的 tensor;RNN 则把图片的 28 行看成 28 个时间步(time step),每个时间步输入一个长度为 28 的向量(即图片的一行)。本例中 LSTM 单元的隐藏层设置为 128 个节点。
import tensorflow as tf
import tensorflow.contrib.rnn
def weight_variable(shape, stddev=0.1):
    """Create a variable initialised from a truncated normal distribution.

    Args:
        shape: Shape of the variable to create.
        stddev: Standard deviation of the truncated normal initialiser.

    Returns:
        A ``tf.Variable`` holding the sampled initial value.
    """
    return tf.Variable(tf.truncated_normal(shape=shape, stddev=stddev))
def bias_variable(shape, alpha=0.1):
    """Create a variable of the given shape filled with a constant.

    Args:
        shape: Shape of the variable to create.
        alpha: Constant value every element is initialised to.

    Returns:
        A ``tf.Variable`` holding the constant initial value.
    """
    return tf.Variable(tf.constant(value=alpha, shape=shape))
def build_rnn(inputs, sequence_size, vector_size, batch_size,
              hidden_size=128, num_classes=10):
    """Build a single-layer LSTM classifier and return its logits.

    The input is reshaped into ``sequence_size`` steps of ``vector_size``
    features, each step is linearly projected to ``hidden_size`` units, the
    sequence is run through one LSTM cell, and the final hidden state is
    linearly mapped to ``num_classes`` logits.

    Args:
        inputs: Float tensor that can be reshaped to
            ``[-1, sequence_size, vector_size]``.
        sequence_size: Number of time steps per example.
        vector_size: Number of features per time step.
        batch_size: Kept for backward compatibility and no longer used; the
            batch size is read from ``inputs`` at run time so the same graph
            accepts batches of any size.
        hidden_size: Width of the input projection and of the LSTM cell.
        num_classes: Number of output logits per example.

    Returns:
        Tensor of shape ``[batch, num_classes]`` with unnormalised logits.
    """
    # Creation order (weights in/out, then biases in/out) is kept so the
    # auto-generated variable names stay compatible with old checkpoints.
    weights = {
        # shape (vector_size, hidden_size)
        'in': tf.Variable(tf.random_normal([vector_size, hidden_size])),
        # shape (hidden_size, num_classes)
        'out': tf.Variable(tf.random_normal([hidden_size, num_classes])),
    }
    biases = {
        # shape (hidden_size, )
        'in': tf.Variable(tf.constant(0.1, shape=[hidden_size, ])),
        # shape (num_classes, )
        'out': tf.Variable(tf.constant(0.1, shape=[num_classes, ])),
    }

    # Apply the input projection to every time step at once, then restore
    # the [batch, time, features] layout expected by dynamic_rnn.
    steps = tf.reshape(inputs, [-1, vector_size])
    projected = tf.matmul(steps, weights['in']) + biases['in']
    sequences = tf.reshape(projected, [-1, sequence_size, hidden_size])

    # lstm cell.
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(
        hidden_size, forget_bias=1.0, state_is_tuple=True)

    # All-zero initial state sized from the runtime batch dimension instead
    # of a fixed Python int, so training and evaluation batches may differ.
    dynamic_batch = tf.shape(sequences)[0]
    init_state = lstm_cell.zero_state(dynamic_batch, dtype=tf.float32)

    # Implement rnn
    _, final_state = tf.nn.dynamic_rnn(
        lstm_cell, sequences, initial_state=init_state, time_major=False)

    # final_state is an LSTM state tuple (c, h); classify from the hidden
    # state h of the last time step.
    return tf.matmul(final_state[1], weights['out']) + biases['out']
运行结果(最后一行的精确率是在随机抽取的一个 batch、即默认 32 张测试图片上计算得到的):
loss : 0.000472 accuracy : 1.000000
loss : 0.000217 accuracy : 1.000000
loss : 0.000051 accuracy : 1.000000
loss : 0.000253 accuracy : 1.000000
loss : 0.000483 accuracy : 1.000000
精确率: 0.96875
如有相关问题,欢迎留言讨论。