Deep Learning Example: Handwritten Digit Recognition on MNIST

This article walks through handwritten digit recognition on the MNIST dataset.

I. Dataset

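    MNIST dataset: 70,000 images of handwritten digits, white on a black background. Of these, 55,000 form the training set, 5,000 the validation set, and 10,000 the test set. Each image is 28*28 pixels; pixel values are floats in [0, 1], where pure black is 0 and pure white is 1. Each label is a one-dimensional array of length 10 in which the element at index i gives the probability that the image shows digit i (with one-hot encoding, the entry for the true digit is 1 and all others are 0). Before an image is fed to the neural network, it must be flattened into a one-dimensional array of length 784, which serves as the network's input features.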
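
    To make the flattening and the label format concrete, here is a minimal sketch in plain Python. The helper names flatten_image and one_hot and the toy image are made up for illustration; they are not part of TensorFlow or of the MNIST loader.

```python
def flatten_image(image):
    """Flatten a 2-D list of pixel rows into a single 1-D list, row by row."""
    return [pixel for row in image for pixel in row]


def one_hot(digit, num_classes=10):
    """Return a list of length num_classes with 1.0 at index `digit`."""
    return [1.0 if i == digit else 0.0 for i in range(num_classes)]


# A made-up 28x28 image: all black (0.0) except one white pixel at row 3, column 5.
image = [[0.0] * 28 for _ in range(28)]
image[3][5] = 1.0

features = flatten_image(image)   # input features for the network
label = one_hot(7)                # label for the digit 7

print(len(features))          # 784
print(features.index(1.0))    # 89, i.e. 3 * 28 + 5
print(label.index(1.0))       # 7
```

    Row-major flattening means the pixel at (row, col) ends up at index row * 28 + col.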

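    1. Use the MNIST dataset that ships with TensorFlow 1.x. It is loaded as follows (data_path is the directory where the data is stored; the files are downloaded there if missing):

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(data_path, one_hot=True)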

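    2. The dataset is split into three subsets: train, validation, and test.

        ① Number of samples in the train set:   mnist.train.num_examples

        ② Number of samples in the validation set:   mnist.validation.num_examples

        ③ Number of samples in the test set:   mnist.test.num_examples

    3. mnist.train.images returns the pixel values of all images in the train set.

    4. mnist.train.labels returns the labels of all images in the train set.

    5. mnist.train.next_batch(batch_size) returns the next batch_size images and labels, which are then fed to the neural network.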
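
    The idea behind batch feeding can be sketched without TensorFlow. The BatchFeeder class below is a hypothetical stand-in written for this article, not the library's implementation: it hands out consecutive slices of a shuffled index list and reshuffles when an epoch runs out.

```python
import random


class BatchFeeder:
    """Hypothetical stand-in for a dataset object with a next_batch method."""

    def __init__(self, images, labels, seed=0):
        assert len(images) == len(labels)
        self.images = images
        self.labels = labels
        self.num_examples = len(images)
        self._rng = random.Random(seed)
        self._order = list(range(self.num_examples))
        self._rng.shuffle(self._order)
        self._pos = 0

    def next_batch(self, batch_size):
        """Return the next batch_size images and their labels as two lists."""
        if self._pos + batch_size > self.num_examples:
            # Not enough samples left: start a new epoch with a fresh shuffle.
            self._rng.shuffle(self._order)
            self._pos = 0
        idx = self._order[self._pos:self._pos + batch_size]
        self._pos += batch_size
        return [self.images[i] for i in idx], [self.labels[i] for i in idx]


# Toy data: 10 "images" (just ids here), each labelled with id % 3.
feeder = BatchFeeder(images=list(range(10)), labels=[i % 3 for i in range(10)])
xs, ys = feeder.next_batch(4)
print(len(xs), len(ys))  # 4 4
```

    Images and labels are indexed with the same shuffled order, so each image stays paired with its own label.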

II. Forward Pass (Computing Predictions)

    The code is shown below (mnist_forward.py).

# _*_coding:utf-8_*_

import tensorflow as tf

input_node = 784
output_node = 10
layer1_node = 500


def get_weight(shape, regularizer):
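    # Draw initial values from a truncated normal distribution: samples more than two standard deviations from the mean are discarded and redrawn
    w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
    # If a regularization coefficient is given, add the L2 loss of w to the 'losses' collection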
    if regularizer is not None:
        tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w


def get_bias(shape):
    b = tf.Variable(tf.zeros(shape))
    return b


def forward(x, regularizer):
    # Input layer to layer 1
    w1 = get_weight(shape=[input_node, layer1_node], regularizer=regularizer)
    b1 = get_bias(shape=[layer1_node])
    y1 = tf.nn.relu(tf.add(tf.matmul(x, w1), b1))
    # Layer 1 to the output layer (no activation here; softmax is applied inside the loss)
    w2 = get_weight(shape=[layer1_node, output_node], regularizer=regularizer)
    b2 = get_bias(shape=[output_node])
    y = tf.add(tf.matmul(y1, w2), b2)
    return y
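
    The truncated normal initialization used in get_weight can be illustrated in plain Python. This is only a sketch of the idea (redraw any sample that lands more than two standard deviations from the mean), not TensorFlow's implementation; the function name and the seed parameter are assumptions made for the example.

```python
import random


def truncated_normal(n, mean=0.0, stddev=0.1, seed=None):
    """Draw n samples from N(mean, stddev^2), redrawing any sample that
    falls more than two standard deviations from the mean."""
    rng = random.Random(seed)
    samples = []
    while len(samples) < n:
        value = rng.gauss(mean, stddev)
        if abs(value - mean) <= 2 * stddev:
            samples.append(value)
    return samples


weights = truncated_normal(1000, stddev=0.1, seed=42)
print(max(abs(w) for w in weights) <= 0.2)  # True
```

    Bounding the initial weights this way avoids the occasional very large value that an untruncated normal distribution would produce.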

III. Backward Pass (Parameter Updates)

    (mnist_backward.py)
# _*_coding:utf-8_*_

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

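batch_size = 200  # number of images per training step
learning_rate_base = 0.1  # initial learning rate
learning_rate_decay = 0.99  # learning rate decay rate
regularizer = 0.0001  # regularization coefficient
total_steps = 50000  # number of training steps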
moving_average_decay = 0.99
model_save_path = './model/'
model_name = 'mnist_model'


def backward(mnist):
    # Placeholders for the input images x and the one-hot labels y
    x = tf.placeholder(tf.float32, [None, mnist_forward.input_node])
    y = tf.placeholder(tf.float32, [None, mnist_forward.output_node])
    y_hat = mnist_forward.forward(x, regularizer)  # logits returned by the forward pass
    global_step = tf.Variable(0, trainable=False)  # training step counter, excluded from training

    # Loss: softmax cross-entropy plus the sum of the regularization losses
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y, 1), logits=y_hat)
    cem = tf.reduce_mean(ce)  # mean cross-entropy over the batch
    loss = cem + tf.add_n(tf.get_collection('losses'))  # total loss, including the regularization term of every weight

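    # Exponentially decaying learning rate:
    # decayed_learning_rate = learning_rate_base * learning_rate_decay ^ (global_step / decay_steps)
    # With staircase=True the exponent is rounded down, so the rate drops once per pass over the training set
    learning_rate = tf.train.exponential_decay(learning_rate_base, global_step, mnist.train.num_examples / batch_size,
                                               learning_rate_decay, staircase=True)
    # Define the optimizer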
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step)

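    # Maintain moving averages of the trainable parameters
    ema = tf.train.ExponentialMovingAverage(moving_average_decay, global_step)
    ema_op = ema.apply(tf.trainable_variables())

    # Use control dependencies so that a single op runs both the training step and the moving-average update
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')

    # Create the Saver used to save and restore the model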
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        ckpt = tf.train.get_checkpoint_state(model_save_path)

        # If a checkpoint exists, restore it (simple and convenient; avoids repeating finished training)
        if ckpt and ckpt.model_checkpoint_path:
            # Restore the session and continue training
            saver.restore(sess, ckpt.model_checkpoint_path)

        # Training loop
        for i in range(total_steps):
            xs, ys = mnist.train.next_batch(batch_size)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y: ys})
            if i % 100 == 0:
                print('After {} train step(s), loss on training batch is {}'.format(step, loss_value))
                # Save the current session to the given path
                saver.save(sess, os.path.join(model_save_path, model_name), global_step=global_step)


def main():
    mnist = input_data.read_data_sets('./mnist', one_hot=True)
    backward(mnist)


if __name__ == '__main__':
    main()
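
    To see what the decaying learning rate does with the settings above, the formula from the code comment can be evaluated in plain Python. This is a sketch of the formula only, not a call into TensorFlow; the function below is written for illustration, and decay_steps is taken as 55000 // 200 = 275 steps per pass over the training set.

```python
def exponential_decay(base_rate, global_step, decay_steps, decay_rate, staircase=False):
    """decayed = base_rate * decay_rate ** (global_step / decay_steps).

    With staircase=True the exponent is rounded down to an integer,
    so the rate changes in discrete steps.
    """
    if staircase:
        exponent = global_step // decay_steps
    else:
        exponent = global_step / decay_steps
    return base_rate * decay_rate ** exponent


decay_steps = 55000 // 200  # 275 training steps per pass over the training set
for step in (0, 274, 275, 550):
    rate = exponential_decay(0.1, step, decay_steps, 0.99, staircase=True)
    print(step, rate)
```

    The rate stays at 0.1 for steps 0 to 274, is multiplied by 0.99 at step 275, and again at step 550.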

IV. Test Code

    (mnist_test.py)
# _*_coding:utf-8_*_

import tensorflow as tf
import time
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward

sleep_time = 5


def mnist_test(mnist):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, shape=[None, mnist_forward.input_node], name='x')
        y = tf.placeholder(tf.float32, shape=[None, mnist_forward.output_node], name='y')
        y_hat = mnist_forward.forward(x, None)

        ema = tf.train.ExponentialMovingAverage(mnist_backward.moving_average_decay)
        ema_restore = ema.variables_to_restore()
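        # Create the Saver that loads the moving-average values into the variables
        saver = tf.train.Saver(ema_restore)

        # Compute the model's accuracy
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_hat, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Evaluate the latest checkpoint in a loop
        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(mnist_backward.model_save_path)
                # If a checkpoint exists, restore the model
                if ckpt and ckpt.model_checkpoint_path:
                    # Restore the session
                    saver.restore(sess, ckpt.model_checkpoint_path)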
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
                    print('After {} training steps, test accuracy is {}'.format(global_step, accuracy_score))
                else:
                    print('No checkpoint file found')
                    return
            time.sleep(sleep_time)


def main():
    mnist = input_data.read_data_sets('./mnist', one_hot=True)
    mnist_test(mnist)


if __name__ == '__main__':
    main()
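
    The test script evaluates the moving averages of the parameters rather than their raw values. The update rule behind an exponential moving average can be sketched in plain Python as below. The function ema_update is written for illustration and is not TensorFlow code; the rule for shrinking the decay when a step count is supplied follows the documented behaviour of tf.train.ExponentialMovingAverage.

```python
def ema_update(shadow, value, decay, num_updates=None):
    """Return the new shadow value: decay * shadow + (1 - decay) * value.

    If num_updates is given, a smaller decay is used early in training:
    min(decay, (1 + num_updates) / (10 + num_updates)).
    """
    if num_updates is not None:
        decay = min(decay, (1 + num_updates) / (10 + num_updates))
    return decay * shadow + (1 - decay) * value


# A parameter that jumps from 0.0 to 1.0 and stays there.
shadow = 0.0
for step in range(3):
    shadow = ema_update(shadow, 1.0, decay=0.99, num_updates=step)
    print(step, shadow)
```

    With a fixed decay of 0.99 the shadow value would move only 1% of the way per step; the step-dependent decay lets it catch up quickly at the start of training.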

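V. Summary

    In my view, one highlight of the code in this article is the checkpoint check below. It is practical: training does not have to start from scratch every time the script runs, which saves time.

        # If a checkpoint exists, restore it (simple and convenient; avoids repeating finished training)
        if ckpt and ckpt.model_checkpoint_path:
            # Restore the session and continue training
            saver.restore(sess, ckpt.model_checkpoint_path)

    Note: This article is mainly based on the course 'Tensorflow笔记' (TensorFlow Notes). The course also covers topics not written up here, such as building your own dataset. If you have questions or suggestions, feel free to leave me a comment. Thanks!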
