# 1. Load the MNIST data set.
import tensorflow as tf
import pylab
from tensorflow.examples.tutorials.mnist import input_data

# read_data_sets reads the MNIST archives from this directory (see the
# "Extracting ..." lines in the output below). one_hot=True encodes every
# label as a 10-element indicator vector instead of a single digit.
# NOTE: tensorflow.examples.tutorials is a TensorFlow 1.x tutorial helper.
mnist_dir = "../MNIST_data/"
mnist = input_data.read_data_sets(mnist_dir, one_hot=True)
结果:
Extracting ../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../MNIST_data/t10k-labels-idx1-ubyte.gz
2、建立模型(本文采用单层双向LSTM网络模型;注释掉的代码中另给出了几种其他RNN结构,仅供参考)
import numpy as np

# Every 28x28 MNIST image is fed to the RNN one pixel row per time step.
n_input = 28     # features per time step (pixels in one row)
n_steps = 28     # sequence length (rows in one image)
n_hidden = 128   # hidden units per RNN cell
n_classes = 10   # output classes (digits 0-9)

# Start from an empty graph so re-running this script does not duplicate ops.
tf.reset_default_graph()

# x: a batch of images as (batch, n_steps, n_input) row sequences.
# y: the matching one-hot labels, (batch, n_classes).
x = tf.placeholder(tf.float32, [None, n_steps, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])

# Checkpoint location written by the training session and read back later.
model_path = "logo/RNNmodel.ckpt"
# ---------------------------------------------------------------------------
# Alternative RNN structures, kept commented out for reference only; the model
# actually built is the static bidirectional LSTM that follows this section.
# Each bare ''' ... ''' string below is a no-op expression statement that just
# records what the preceding print() showed. The tensor reprs inside
# LSTMStateTuple(c=..., h=...) appear to have been lost when the text was copied.
# ---------------------------------------------------------------------------
# Variant 1: single-layer static RNN built from a BasicLSTMCell.
# x1=tf.unstack(x,n_steps,1)
# lstm_cell=tf.contrib.rnn.BasicLSTMCell(n_hidden,forget_bias=1.0)
# outputs,states=tf.contrib.rnn.static_rnn(lstm_cell,x1,dtype=tf.float32)
# print(len(outputs),outputs[0],len(states),states)
'''
28 Tensor("rnn/rnn/basic_lstm_cell/mul_2:0", shape=(?, 128), dtype=float32)
2 LSTMStateTuple(c=,
h=)
'''
# Variant 2: single-layer dynamic RNN built from a GRUCell.
# gru=tf.contrib.rnn.GRUCell(n_hidden)
# outputs,states=tf.nn.dynamic_rnn(gru,x,dtype=tf.float32)
# print(outputs.shape,states.shape) #(?, 28, 128) (?, 128)
# outputs=tf.transpose(outputs,[1,0,2])
# print(outputs.shape,states.shape) #(28, ?, 128) (?, 128)
# pred=tf.contrib.layers.fully_connected(outputs[-1],n_classes,activation_fn=None)
# print(outputs[-1].shape) #(?, 128)
# Variant 3: static 3-layer LSTM stacked with MultiRNNCell.
# stacked_rnn=[]
# for i in range(3):
#     stacked_rnn.append(tf.contrib.rnn.LSTMCell(n_hidden))
# mcell=tf.contrib.rnn.MultiRNNCell(stacked_rnn)
# x1=tf.unstack(x,n_steps,1)
# outputs,states=tf.contrib.rnn.static_rnn(mcell,x1,dtype=tf.float32)
# print(np.shape(outputs),np.shape(states),states[-1])
'''
(28,)
(3, 2)
LSTMStateTuple(c=,
h=)
'''
# Variant 4: dynamic 2-layer RNN stacking an LSTMCell (n_hidden units) under a
# GRUCell (n_hidden*4 = 512 units) with MultiRNNCell; the output width is the
# top (GRU) layer's 512, as the recorded output below shows.
# gru=tf.contrib.rnn.GRUCell(n_hidden*4)
# lstm_cell=tf.contrib.rnn.LSTMCell(n_hidden)
# mcell=tf.contrib.rnn.MultiRNNCell([lstm_cell,gru])
# outputs,states = tf.nn.dynamic_rnn(mcell,x,dtype=tf.float32)#(?, 28, 512)
# outputs = tf.transpose(outputs, [1, 0, 2])#(28, ?, 512): 28 time steps; take the last one, outputs[-1] -> (?, 512)
# print(outputs,states)
'''
Tensor("transpose_1:0", shape=(28, ?, 512), dtype=float32)
(LSTMStateTuple(c=,
h=),
)
'''
# Variant 5: single-layer dynamic bidirectional RNN (tf.nn.bidirectional_dynamic_rnn).
# lstm_fw_cell=tf.contrib.rnn.BasicLSTMCell(n_hidden,forget_bias=1.0)
# lstm_bw_cell=tf.contrib.rnn.BasicLSTMCell(n_hidden,forget_bias=1.0)
# outputs,states=tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell,lstm_bw_cell,x,dtype=tf.float32)
# print(len(outputs),outputs)
# print(states)
# '''
# 2
# (,
# )
# (LSTMStateTuple(c=,
# h=),
# LSTMStateTuple(c=,
# h=))
# '''
# print("*************************************")
# print(outputs[0].shape,outputs[1].shape)#(?, 28, 128) (?, 28, 128)
# outputs=tf.concat(outputs,2)
# outputs=tf.transpose(outputs,[1,0,2])
# Active model: single-layer static bidirectional LSTM.
# static_bidirectional_rnn takes a Python list of n_steps tensors, so split x
# along its time axis (axis 1) into n_steps tensors of shape (batch, n_input).
x1 = tf.unstack(x, n_steps, 1)
# Forward cell: reads the rows in input order.
lstm_fw_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
# Backward cell: reads the rows in reverse order.
lstm_bw_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
# outputs is a list of n_steps tensors, each (batch, 2*n_hidden) laid out as
# [forward | backward]; backward outputs are re-aligned to input time order.
outputs, state_fw, state_bw = tf.contrib.rnn.static_bidirectional_rnn(
    lstm_fw_cell, lstm_bw_cell, x1, dtype=tf.float32)
print(outputs[0].shape, len(outputs))

# outputs[-1] alone is a poor summary for a bidirectional RNN: its forward half
# has read every row, but its backward half has read only the LAST row.
# Take each direction at the step where it has consumed the whole sequence:
# forward at the last step, backward at the first step.
last_features = tf.concat(
    [outputs[-1][:, :n_hidden], outputs[0][:, n_hidden:]], axis=1)
# Linear read-out to unscaled class logits (softmax is applied in the loss).
pred = tf.contrib.layers.fully_connected(
    last_features, n_classes, activation_fn=None)

# A Saver with no var_list captures the variables that exist now. It is built
# before the optimizer, so checkpoints hold the model weights only (no Adam
# slot variables).
saver = tf.train.Saver()
# Training hyper-parameters ("learing_rate" is a misspelling kept for compatibility).
learing_rate = 0.001
training_iters = 100000   # training stops once step * batch_size reaches this
batch_size = 128
display_step = 10         # print minibatch metrics every display_step steps

# Loss: mean softmax cross-entropy between the logits `pred` and one-hot `y`.
per_example_loss = tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y)
cost = tf.reduce_mean(per_example_loss)
optimizer = tf.train.AdamOptimizer(learning_rate=learing_rate).minimize(cost)

# Accuracy: fraction of rows whose highest-scoring class equals the label.
predicted_class = tf.argmax(pred, 1)
true_class = tf.argmax(y, 1)
correct_pred = tf.equal(predicted_class, true_class)
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
结果(以下为第4步训练代码的运行输出,仅截取最后一部分):
Iter 85760, Minibatch Loss= 0.102439, Training Accuracy= 0.96875
Iter 87040, Minibatch Loss= 0.154755, Training Accuracy= 0.94531
Iter 88320, Minibatch Loss= 0.097860, Training Accuracy= 0.96875
Iter 89600, Minibatch Loss= 0.149429, Training Accuracy= 0.95312
Iter 90880, Minibatch Loss= 0.093206, Training Accuracy= 0.96875
Iter 92160, Minibatch Loss= 0.100381, Training Accuracy= 0.97656
Iter 93440, Minibatch Loss= 0.126264, Training Accuracy= 0.96875
Iter 94720, Minibatch Loss= 0.096257, Training Accuracy= 0.96094
Iter 96000, Minibatch Loss= 0.177734, Training Accuracy= 0.94531
Iter 97280, Minibatch Loss= 0.181315, Training Accuracy= 0.94531
Iter 98560, Minibatch Loss= 0.197556, Training Accuracy= 0.92969
Iter 99840, Minibatch Loss= 0.125003, Training Accuracy= 0.96094
Finished!
Testing Accuracy: 0.96875
Model saved in file: logo/RNNmodel.ckpt
4、训练模型(保存模型)
# 4. Train the model, check accuracy on a slice of the test set, save a checkpoint.
import os

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    step = 1
    # One minibatch per step; stop after roughly training_iters samples.
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Flat 784-pixel images -> (batch, n_steps, n_input) row sequences.
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % display_step == 0:
            # Fetch both metrics in one run instead of two forward passes.
            acc, loss = sess.run([accuracy, cost],
                                 feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step * batch_size) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))
        step += 1
    print(" Finished!")

    # Accuracy on the first 128 MNIST test images.
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

    # TF1 Saver.save raises ValueError when the checkpoint's parent directory
    # does not exist, so create it first (skipped for a bare filename).
    ckpt_dir = os.path.dirname(model_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    # Save model weights to disk.
    save_path = saver.save(sess, model_path)
    print("Model saved in file: %s" % save_path)
5、读取模型与测试结果
# 5. Restore the saved weights in a fresh session and classify a few test images.
print("Starting 2nd session...")
with tf.Session() as sess:
    # Initialize everything first so graph variables that are not in the
    # checkpoint (e.g. optimizer slots, if the Saver was built before the
    # optimizer) still have values; restore then overwrites the saved ones.
    sess.run(tf.global_variables_initializer())
    # Restore model weights from the previously saved checkpoint.
    saver.restore(sess, model_path)

    # Optional: accuracy over the whole test set (images must be reshaped to
    # row sequences to match the x placeholder):
    # correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # print("Accuracy:", accuracy.eval({x: mnist.test.images.reshape((-1, n_steps, n_input)),
    #                                   y: mnist.test.labels}))

    output = tf.argmax(pred, 1)
    n_show = 2
    # Draw from the held-out test split (previously drawn from the training set).
    batch_xs, batch_ys = mnist.test.next_batch(n_show)
    print(batch_xs.shape, batch_ys.shape)
    batch_xs = batch_xs.reshape((n_show, n_steps, n_input))
    outputval = sess.run(output, feed_dict={x: batch_xs})
    print(outputval)                  # predicted digits
    print(batch_ys.argmax(axis=1))    # true digits, same order
    # Display each image in the same order as the predictions above.
    for im in batch_xs:
        pylab.imshow(im.reshape(-1, 28))
        pylab.show()
结果:
Starting 2nd session...
INFO:tensorflow:Restoring parameters from logo/RNNmodel.ckpt
(2, 784) (2, 10)
[6 8]