# RNN识别mnist手写数字数据集

# -*- coding: utf-8 -*-


# @Time    : 2018/12/24 11:04
# @Author  : WenZhao
# @Email   : [email protected]
# @File    : mnistRnn-1.py
# @Software: PyCharm
'''
    RNN识别mnist手写数字数据集
'''

import tensorflow as tf
import numpy as np
# 下载数据集
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from ./data/MNIST_data/ with one-hot encoded labels.
mnist = input_data.read_data_sets("./data/MNIST_data/", one_hot=True)

# Optimisation hyper-parameters.
learning_rate = 0.001
batch_size = 128

# Every image is presented to the RNN as a sequence of n_steps rows,
# each row holding n_input pixel values.
n_input = 28
n_steps = 28
n_hidden = 128   # number of units in the recurrent cell
n_classes = 10   # width of the one-hot label vector

# Graph inputs; the leading (batch) dimension is left unspecified.
x = tf.placeholder(tf.float32, shape=[None, n_steps, n_input])
y = tf.placeholder(tf.float32, shape=[None, n_classes])


# Run a GRU over each image, one pixel row per time step.  Every sequence
# has the full n_steps length, so no sequence_length is passed; this keeps
# the graph independent of the batch size, matching the None batch dimension
# of the placeholders.
output, _ = tf.nn.dynamic_rnn(
    tf.contrib.rnn.GRUCell(n_hidden),
    x,
    dtype=tf.float32,
)

# Output of the final time step for every example: [batch, n_hidden].
last = output[:, -1, :]

# Linear read-out layer mapping the last GRU output to class scores.
num_classes = int(y.get_shape()[1])
weight = tf.Variable(tf.truncated_normal([n_hidden, num_classes], stddev=0.01))
bias = tf.Variable(tf.constant(0.1, shape=[num_classes]))
logits = tf.matmul(last, weight) + bias
prediction = tf.nn.softmax(logits)

# Summed cross-entropy over the batch.  It is computed from the logits by the
# fused op rather than as -sum(y * log(softmax)), which yields NaN as soon as
# any predicted probability underflows to zero.
cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits)
)

# Adam with every gradient clipped to an L2 norm of at most 5; variables
# without a gradient are passed through untouched.
optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = [
    (tf.clip_by_norm(grad, 5), var) if grad is not None else (grad, var)
    for grad, var in optimizer.compute_gradients(cross_entropy)
]

train_op = optimizer.apply_gradients(grads)

# Per-example correctness flags and their mean (the batch accuracy).
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Train for a fixed number of mini-batch steps, then report accuracy on the
# test set.  The session is used as a context manager so that it is closed
# even if training raises.
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for step in range(1300):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape each image into n_steps rows of n_input pixels.
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        train_feed = {x: batch_x, y: batch_y}
        sess.run(train_op, feed_dict=train_feed)
        if step % 50 == 0:
            # Fetch both metrics in a single forward pass over the batch
            # that was just trained on.
            acc, loss = sess.run([accuracy, cross_entropy], feed_dict=train_feed)

            print("Iter:"+str(step)+",Minibatch Loss="+str(loss)+",Train Accuracy:"+str(acc))

    print("OK")

    test_x = mnist.test.images.reshape(-1, n_steps, n_input)
    test_y = mnist.test.labels

    # Score the whole test set rather than only its first batch.  Every feed
    # has exactly batch_size rows, the size used during training: the final
    # window is shifted back so it ends on the last example, and only the
    # rows not already counted by the previous window are scored.
    num_test = len(test_y)
    hits = 0
    for start in range(0, num_test, batch_size):
        stop = min(start + batch_size, num_test)
        first = max(stop - batch_size, 0)
        flags = sess.run(
            correct_pred,
            feed_dict={x: test_x[first:stop], y: test_y[first:stop]},
        )
        hits += int(flags[start - first:].sum())
    acc = hits / float(num_test)

    print(acc)

 

# 你可能感兴趣的:(深度学习)