《深度学习之TensorFlow》reading notes(3)—— MNIST手写数字识别之二

文章目录

  • 模型保存
  • 模型读取
  • 测试模型
    • 搭建测试模型
  • 使用模型
  • 模型可视化

本文是在上一篇文章 《深度学习之TensorFlow》reading notes(2)—— MNIST手写数字识别的基础上写的,主要内容是进一步实现对模型的测试、保存和模型读取使用。

模型保存

先上代码:

import...
#建立模型...
#配置参数...
saver = th.train.Saver()
model_path = "log/521model.ckpt"

with tf.Session as sess:
	...#这里是初始化和训练模型的过程
	#保存模型并打印
	save_path = saver.save(sess, model_path)
	print("Model saved in file: %s" % save_path)

只需要添加四句话就可以实现对训练模型的保存。好像也没啥需要解释的……

模型读取

import tensorflow as tf #导入tensorflow库
import pylab 
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

tf.reset_default_graph()
# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784])  # mnist data维度 28*28=784
y = tf.placeholder(tf.float32, [None, 10])  # 0-9 数字=> 10 classes
# Set model weights
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))

model_path = "log/521model.ckpt"

#读取模型
print("Starting 2nd session...")
saver = tf.train.Saver()
with tf.Session() as sess:
    # Initialize variables
#    sess.run(tf.global_variables_initializer())
    # Restore model weights from previously saved model
    saver.restore(sess, model_path)

读取模型参数前,一样需要将模型中的张量重新定义一遍,其实保存的模型中的参数值和模型结构,也就是W和b的值和向前传播的结构。用

saver.restore(sess, model_path)

就可以实现对模型的读取了,下面我们测试读取得到的模型,以及使用模型对一张手写图进行判断。

测试模型

搭建测试模型

首先,需要调用测试数据集中的数据,输入到模型中,看模型预测的结果与数据集的标签是否一致,如果一致则返回true,不一致则返回false。最后统计所有true的个数除以总数,即为模型准确率。

import tensorflow as tf  # 导入tensorflow库
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

tf.reset_default_graph()
# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784])  # mnist data维度 28*28=784
y = tf.placeholder(tf.float32, [None, 10])  # 0-9 数字=> 10 classes
# Set model weights
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
# 构建测试模型
s1 = tf.matmul(x, W) + b
pred = tf.nn.softmax(s1)
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
model_path = "log/521model.ckpt"
saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, model_path)
    Accaa = sess.run(accuracy,feed_dict={x: mnist.test.images, y: mnist.test.labels})
    print ("Accuracy:", Accaa)

首先,依旧是用模型求取pred值,也就是预测值。之后一句:

correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))

其中tf.argmax()函数是用来检索最大值的,即得到最大值位置。由于y是onehot编码,所以,这个索引对应上即是预测正确,再用tf.equal()函数来确定其是否相等,就能得到正确的情况。
在通过下一句求平均,其实就是计数正确的个数,再除以总数。就可以得到准确率了。
运行结果:

runfile('E:/mnist_1/mnist_test.py', wdir='E:/mnist_1')
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
INFO:tensorflow:Restoring parameters from log/521model.ckpt
Accuracy: 0.8587

可以看到,判断的准确率还可以,达到了85.87%,这其实是我迭代了50次的结果,25次通常维持在82%~83%左右。再迭代效果不明显了。

使用模型

使用模型在读取模型之后,选用test数据集中的数据:mnist.test.next_batch(num)其中num为调用多少个数据进行测试。sess.run中将测试数据作为输入,得到模型预测的output和预测概率pred。最后,将预测值、预测概率、标签值和图形都进行输出。

import tensorflow as tf #导入tensorflow库
import pylab 
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)


tf.reset_default_graph()
# tf Graph Input
x = tf.placeholder(tf.float32, [None, 784])  # mnist data维度 28*28=784

# Set model weights
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))

# 构建模型
s1 = tf.matmul(x, W) + b
pred = tf.nn.softmax(s1)
output = tf.argmax(pred, 1)

model_path = "log/521model.ckpt"

num11 = 5
#读取模型
print("Starting 2nd session...")
saver = tf.train.Saver()
with tf.Session() as sess:
    # Initialize variables
#    sess.run(tf.global_variables_initializer())
    # Restore model weights from previously saved model
    saver.restore(sess, model_path)
    
 
    batch_xs, batch_ys = mnist.test.next_batch(num11)
    outputval,predv = sess.run([output,pred], feed_dict={x: batch_xs})
        
    for i in range(num11):
        print(outputval[i],predv[i,outputval[i]],batch_ys[i])
        im = batch_xs[i]
        im = im.reshape(-1,28)
        pylab.imshow(im)
        pylab.show()

输入结果:这里我取了两个0作为比较,可以看到第一个〇,预测概率为100%,第二个〇,预测概率为84%,从实际图中可以看出区别。
《深度学习之TensorFlow》reading notes(3)—— MNIST手写数字识别之二_第1张图片
《深度学习之TensorFlow》reading notes(3)—— MNIST手写数字识别之二_第2张图片

模型可视化

下次再说,玩一会自走棋去~

你可能感兴趣的:(deep,learning)