手把手教你用tensorflow做无人驾驶(二)-LSTM应用实战

通过这篇博客,你可学到怎么在tensorflow环境下搭建LSTM网络(这里包括单层与多层),同时使用matplotlib模块画图,通过训练完以后,把网络保存下来,以后再次打开网络就不需要再次训练网络,直接用即可。这里我会演示保存下来的网络怎么恢复以及使用保存下来的网络进行测试,就不要训练了。首先建立一个LSTM.py,代码如下:

from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import rnn

import matplotlib.pyplot as plt

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)    # MNIST_data指的是存放数据的文件夹路径,one_hot=True 为采用one_hot的编码方式编码标签

tf.reset_default_graph()

##//////新增打印///////////
plotdata={"batchsize":[],"loss":[]}

#///////////////////////////////


# 训练参数
learning_rate = 0.001 # 学习率
training_steps = 10000 # 总迭代次数
batch_size = 128 # 批量大小
display_step = 200

# 网络参数
num_input = 28 # MNIST数据集图片: 28*28
timesteps = 28 # timesteps
num_hidden = 128 # 隐藏层神经元数
num_classes = 10 # MNIST 数据集类别数 (0-9 digits)

# 定义输入
X = tf.placeholder("float", [None, timesteps, num_input],name="input_x")
Y = tf.placeholder("float", [None, num_classes],name="input_y")

# 定义权重和偏置
# weights矩阵[128, 10]
weights = {
    'out': tf.Variable(tf.random_normal([num_hidden, num_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([num_classes]))
}

# 定义LSTM网络
def LSTM(x, weights, biases):

    # Prepare data shape to match `rnn` function requirements
    # 输入数据x的shape: (batch_size, timesteps, n_input)
    # 需要的shape: 按 timesteps 切片,得到 timesteps 个 (batch_size, n_input)

    # 对x进行切分
    # tf.unstack(value,num=None,axis=0,name='unstack')
    # value:要进行分割的tensor
    # axis:整数,打算进行切分的维度
    # num:整数,axis(打算切分)维度的长度
    x = tf.unstack(x, timesteps, 1)

    # 定义一个lstm cell,即上面图示LSTM中的A
    # n_hidden表示神经元的个数,forget_bias就是LSTM们的忘记系数,如果等于1,就是不会忘记任何信息。如果等于0,就都忘记。
    lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)


#    #/////////////添加多层///////////
#    lstm_cell_1 = tf.contrib.rnn.BasicLSTMCell(num_units=512)
#    lstm_cell_2 = tf.contrib.rnn.BasicLSTMCell(num_units=256)
#    lstm_cell_3 = tf.contrib.rnn.BasicLSTMCell(num_units=128)
#    lstm_cell=tf.contrib.rnn.MultiRNNCell(cells=[lstm_cell_1,lstm_cell_2,lstm_cell_3])
    #////////////////////////////////////

    # 得到 lstm cell 输出
    # 输出output和states
    # outputs是一个长度为T的列表,通过outputs[-1]取出最后的输出
    # state是最后的状态
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)

    # 线性激活
    # 矩阵乘法
    return tf.matmul(outputs[-1], weights['out']) + biases['out']


logits = LSTM(X, weights, biases)
prediction = tf.nn.softmax(logits)

##////////,这里值是进行了模型的保存,这里保存的目的是为了进行加载并对输入的数据进行测试,并且不需要重建整个网络。

tf.add_to_collection('prediction',prediction)

#//////////////////////////


# 定义损失函数和优化器
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
    logits=logits, labels=Y))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)

# 模型评估(with test logits, for dropout to be disabled)
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# 初始化全局变量
init = tf.global_variables_initializer()


#////新增程序//////保存数据首先创建一个对象

saver=tf.train.Saver()

#///////////////////




# Start training
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    for step in range(1, training_steps+1):
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 seq of 28 elements
        batch_x = batch_x.reshape((batch_size, timesteps, num_input))
        # Run optimization op (backprop)
        sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})
        if step % display_step == 0 or step == 1:
            # Calculate batch loss and accuracy
            loss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x,
                                                                 Y: batch_y})
            print("Step " + str(step) + ", Minibatch Loss= " + \
                  "{:.4f}".format(loss) + ", Training Accuracy= " + \
                  "{:.3f}".format(acc))
            model_path = "./model/my_model"
            save_path = saver.save(sess, model_path)
            plotdata["batchsize"].append(step)
            plotdata["loss"].append(loss)
     #////////新增画图///////////       
    plt.plot(plotdata["batchsize"],plotdata["loss"],'b--')
    plt.xlabel('minibatch number')
    plt.ylabel('loss')
    plt.title('Training loss')
    plt.show()
    #//////////////////////////
    print("Optimization Finished!")

    # Calculate accuracy for 128 mnist test images
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, timesteps, num_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:", \
        sess.run(accuracy, feed_dict={X: test_data, Y: test_label}))

1.上述代码需要注意的地方,如果想要多层网络把这个注释去掉,

#    #/////////////添加多层///////////
#    lstm_cell_1 = tf.contrib.rnn.BasicLSTMCell(num_units=512)
#    lstm_cell_2 = tf.contrib.rnn.BasicLSTMCell(num_units=256)
#    lstm_cell_3 = tf.contrib.rnn.BasicLSTMCell(num_units=128)
#    lstm_cell=tf.contrib.rnn.MultiRNNCell(cells=[lstm_cell_1,lstm_cell_2,lstm_cell_3])
    #////////////////////////////////////
把lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)

注释掉即可。

2.打印在代码中定义了存储变量:

##//////新增打印///////////
plotdata={"batchsize":[],"loss":[]}

#///////////////////////////////

3.想要保存模型首先创建一个对象
     


#////新增程序//////保存数据首先创建一个对象

saver=tf.train.Saver()

#///////////////////

 然后通过:

 model_path = "./model/my_model"
 save_path = saver.save(sess, model_path)

这两行代码把模型保存到当前程序所在目录model文件夹下,名字为my_model.

手把手教你用tensorflow做无人驾驶(二)-LSTM应用实战_第1张图片

为了在验证阶段使用模块,程序代码中增加了:

##////////,这里值是进行了模型的保存,这里保存的目的是为了进行加载并对输入的数据进行测试,并且不需要重建整个网络。

tf.add_to_collection('prediction',prediction)

#//////////////////////////

现在把模型保存好了,开始用训练完保存好的模型进行测试,建立一个LSTM_test.py,代码如下:

import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import pylab
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

n_input = 28
n_steps = 28
n_classes = 10

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./model/my_model.meta')
    new_saver.restore(sess, './model/my_model')

    graph = tf.get_default_graph()
    predict = tf.get_collection('prediction')[0]
    input_x = graph.get_operation_by_name("input_x").outputs[0]
    
#    im = Image.open('./MNIST_data/3.jpg')
#    
#    pylab.imshow(im)
#    plt.show()
#    im = im.convert('L')
#    tv = list(im.getdata()) 
#    tva = [(255-x)*1.0/255.0 for x in tv]
#    im=np.array(im)
#    im1=im.reshape((1, n_steps, n_input))
#
#    print(im.shape)
#   
#    res = sess.run(predict, feed_dict={input_x: im1 })
#    print("predict class ",str(sess.run(tf.argmax(res, 1))))

    x = mnist.test.images[1].reshape((1, n_steps, n_input))
    y = mnist.test.labels[1].reshape(-1, n_classes)  # 转为one-hot形式
    x1=x.reshape(-1,28)   
    pylab.imshow(x1)
    plt.show()
    

    res = sess.run(predict, feed_dict={input_x: x })

    print("Actual class: ", str(sess.run(tf.argmax(y, 1))), \
          ", predict class ",str(sess.run(tf.argmax(res, 1))), \
          ", predict ", str(sess.run(tf.equal(tf.argmax(y, 1), tf.argmax(res, 1))))
          )

下面这两行代码就是从保存好的模型中恢复参数:

new_saver = tf.train.import_meta_graph('./model/my_model.meta')
new_saver.restore(sess, './model/my_model')

通过下面这两行进行计算与输入的提取

predict = tf.get_collection('prediction')[0]
input_x = graph.get_operation_by_name("input_x").outputs[0]

下面代码进行图片数据处理,这里,也可以看看输入图片是什么样子的

x = mnist.test.images[1].reshape((1, n_steps, n_input))
y = mnist.test.labels[1].reshape(-1, n_classes)  # 转为one-hot形式
x1=x.reshape(-1,28)   
pylab.imshow(x1)
plt.show()

手把手教你用tensorflow做无人驾驶(二)-LSTM应用实战_第2张图片

下面这段代码就是预测与给定比较

 res = sess.run(predict, feed_dict={input_x: x })

    print("Actual class: ", str(sess.run(tf.argmax(y, 1))), \
          ", predict class ",str(sess.run(tf.argmax(res, 1))), \
          ", predict ", str(sess.run(tf.equal(tf.argmax(y, 1), tf.argmax(res, 1))))
          )

结果如下:

这样结果就是预测与给定标注是符合的。

    

你可能感兴趣的:(tensorflow,LSTM)