基于卷积神经网络LeNet-5模型的mnist手写数字识别

基于卷积神经网络LeNet-5模型的mnist手写数字识别

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import os
'''
tensorflow中os.environ["TF_CPP_MIN_LOG_LEVEL"]的值的含义
log信息共有四个等级,按重要性递增为:
INFO(通知)= 10:
            LEARN_RATE = LEARN_RATE * LEARN_DAMP
            train_step = tf.train.GradientDescentOptimizer(LEARN_RATE).minimize(loss)

        for batch in range(N_BATCH):
            batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE) # 依次往后取训练集中的一个批次数据
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.8})

        # 验证集的数字识别正确率
        validation_acc = sess.run(accuracy, feed_dict={x: mnist.validation.images, y: mnist.validation.labels,
                                                       keep_prob: 1.0})
        # 测试集的数字识别正确率
        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})

        print('Iter' + str(epoch) + ',Validation Accuracy' + str(validation_acc) + ',Testing Accuracy' + str(test_acc))
    #保存模型
    saver.save(sess, 'NET/my_mnist_LeNet-5.ckpt')

测试结果如下所示:正确率达到99.4%左右。
基于卷积神经网络LeNet-5模型的mnist手写数字识别_第1张图片
接下来调用保存的模型进行简单对比测试:

import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
'''
tensorflow中os.environ["TF_CPP_MIN_LOG_LEVEL"]的值的含义
log信息共有四个等级,按重要性递增为:
INFO(通知)

对比测试结果如下:
基于卷积神经网络LeNet-5模型的mnist手写数字识别_第2张图片

你可能感兴趣的:(深度学习)