TensorFlow 实现RNN

TensorFlow 实现RNN

1. 导入所需包

import  tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data # 获取数据
import numpy as np

2. 获取数据

# Read the MNIST splits from data_path; one_hot=True yields 10-wide label rows.
data_path = r'C:\Users\liev\Desktop\myproject\yin_test\MNIST_DATA_TensorFlow'
mnist = input_data.read_data_sets(data_path, one_hot=True)

# Report the array shape of the images and labels in the train and test splits.
for _caption, _split_array in (
    ('训练集图片信息: ', mnist.train.images),
    ('训练集标签信息: ', mnist.train.labels),
    ('测试集图片信息: ', mnist.test.images),
    ('测试集标签信息: ', mnist.test.labels),
):
    print(_caption, np.array(_split_array).shape)

输出:

训练集图片信息:  (55000, 784)
训练集标签信息:  (55000, 10)
测试集图片信息:  (10000, 784)
测试集标签信息:  (10000, 10)

3. 网络结构

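RNN网络处理数据过程: 每张 784 维的图片被还原为 28 行 × 28 列, 按行视为长度为 28 的序列; 每一行先经过全连接层映射到 128 维, 再依次送入 LSTM, 最后取最后一个时间步的输出, 经全连接层和 softmax 得到 10 个类别的概率。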

TensorFlow 实现RNN_第1张图片

代码实现:

class RNN:
    """Row-by-row LSTM classifier for flattened 28x28 images.

    Each image is split into 28 rows of 28 values. Every row is projected to
    128 features by a dense layer, the 28 projected rows are fed to an LSTM as
    a length-28 sequence, and the LSTM output at the last time step is mapped
    to 10 softmax scores.
    """

    def __init__(self):
        # Per-row input projection: 28 pixel values -> 128 features.
        self.in_w = tf.Variable(tf.truncated_normal([28, 128], stddev=0.1))
        self.in_b = tf.Variable(tf.zeros([128]))

        # Output layer: last LSTM output (128) -> 10 class scores.
        self.out_w = tf.Variable(tf.truncated_normal([128, 10], stddev=0.1))
        self.out_b = tf.Variable(tf.zeros([10]))

    def forward(self, in_x):
        """Build the forward graph and return the class probabilities.

        Args:
            in_x: float32 tensor of flattened images; Net feeds a
                [batch, 784] placeholder. Any batch size is accepted.

        Returns:
            Tensor of shape [batch, 10] holding the softmax output.
        """
        # [batch, 784] -> [batch * 28, 28]: one row of pixels per matrix row.
        rows = tf.reshape(in_x, [-1, 28])
        # Dense + leaky ReLU on every row: [batch * 28, 128].
        projected = tf.nn.leaky_relu(
            tf.add(tf.matmul(rows, self.in_w), self.in_b))
        # Regroup the rows into sequences: [batch, 28 steps, 128 features].
        sequence = tf.reshape(projected, [-1, 28, 128])

        cell = tf.nn.rnn_cell.LSTMCell(128)
        # Passing dtype (instead of a zero_state built for a fixed batch of
        # 100) lets dynamic_rnn create the zero initial state from the actual
        # batch size of the input.
        outputs, _ = tf.nn.dynamic_rnn(
            cell, sequence, dtype=tf.float32, time_major=False)

        # outputs is [batch, 28, 128]; keep only the last time step.
        last_step = outputs[:, -1, :]
        return tf.nn.softmax(
            tf.add(tf.matmul(last_step, self.out_w), self.out_b))

class Net:
    """Connects the RNN model to input placeholders, a loss and an optimizer."""

    def __init__(self):
        self.rnn = RNN()
        # Inputs: 784 values per image and 10 values per label; the batch
        # dimension is left open.
        self.in_x = tf.placeholder(shape=[None, 784], dtype=tf.float32)
        self.in_y = tf.placeholder(shape=[None, 10], dtype=tf.float32)

        # Build predictions first; the loss and optimizer are defined on them.
        self.forward()
        self.backward()

    def forward(self):
        """Store the model output for `in_x` in `self.out`."""
        predictions = self.rnn.forward(self.in_x)
        self.out = predictions

    def backward(self):
        """Define the mean-squared-error loss and an Adam op minimizing it."""
        squared_error = (self.out - self.in_y) ** 2
        self.loss = tf.reduce_mean(squared_error)
        optimizer = tf.train.AdamOptimizer()
        self.opt = optimizer.minimize(self.loss)

4. 训练网络

if __name__ == '__main__':
    # Train for 2000 mini-batch steps of 100 samples; every 200 steps save a
    # checkpoint and print the loss plus the accuracy on one test batch.
    net = Net()

    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        saver = tf.train.Saver()
        ckpt_path = r'C:\Users\liev\Desktop\myproject\yin_test\log_TensorFlow\RNN.ckpt'

        for step in range(2000):
            xs, ys = mnist.train.next_batch(100)
            loss, _ = sess.run([net.loss, net.opt],
                               feed_dict={net.in_x: xs, net.in_y: ys})

            if step % 200 == 0:
                saver.save(sess, ckpt_path)
                print('loss: ', loss)
                # Only the predictions are fetched, so labels need not be fed.
                test_xs, test_ys = mnist.test.next_batch(100)
                out_y = sess.run(net.out, feed_dict={net.in_x: test_xs})
                accuracy = np.mean(
                    np.argmax(out_y, axis=1) == np.argmax(test_ys, axis=1))
                print('accuracy: ', accuracy)

        # The periodic save last fires at step 1800; save once more so the
        # checkpoint holds the weights from the end of training.
        saver.save(sess, ckpt_path)

输出:

loss:  0.09001497
accuracy:  0.1
loss:  0.021020493
accuracy:  0.9
loss:  0.0061496254
accuracy:  0.99
loss:  0.000798139
accuracy:  0.98

损失图:

TensorFlow 实现RNN_第2张图片

5. 测试网络

if __name__ == '__main__':
    # Restore the trained checkpoint, measure accuracy on 2000 test batches of
    # 100 samples, then print the mean accuracy and plot the per-batch values.

    # The plotting calls below use `plt`, which this file never imports;
    # import pyplot locally so the script does not end in a NameError.
    import matplotlib.pyplot as plt

    net = Net()
    accuracy_count = []
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        saver = tf.train.Saver()
        saver.restore(sess, r'C:\Users\liev\Desktop\myproject\yin_test\log_TensorFlow\RNN.ckpt')

        for _ in range(2000):
            test_xs, test_ys = mnist.test.next_batch(100)
            # Only the predictions are fetched, so labels need not be fed.
            out_y = sess.run(net.out, feed_dict={net.in_x: test_xs})
            accuracy = np.mean(
                np.argmax(out_y, axis=1) == np.argmax(test_ys, axis=1))
            accuracy_count.append(accuracy)
            print('accuracy: ', accuracy)

        print('平均准确度: ', sum(accuracy_count) / len(accuracy_count))
        plt.figure('RNN_Accuracy')
        plt.plot(accuracy_count, 'o', label='accuracy')
        plt.legend()
        plt.show()

输出:

accuracy:  1.0
accuracy:  0.99
accuracy:  0.97
平均准确度:  0.9755000000000169

TensorFlow 实现RNN_第3张图片

你可能感兴趣的:(RNN)