MNISTs手写数字识别源代码

model.py

import tensorflow as tf


class Network:
    def __init__(self):
        self.learning_rate = 0.001
        # 记录已经训练的次数
        self.global_step = tf.Variable(0, trainable=False)

        # 输入张量 28 * 28 = 784个像素的图片一维向量
        self.x = tf.placeholder(tf.float32, [None, 784])
        # 标签值,即图像对应的结果,如果对应数字是8,则对应label是 [0,0,0,0,0,0,0,0,1,0]
        # 标签是一个长度为10的一维向量,值最大的下标即图片上写的数字,采用独热编码,最大的值即是标签对应的值
        self.label = tf.placeholder(tf.float32, [None, 10])

        #权重,全部初始化为0
        self.w = tf.Variable(tf.zeros([784, 10]))
        # 偏置 bias, 初始化全 0
        self.b = tf.Variable(tf.zeros([10]))
        # 输出 y = softmax(X * w + b)
        self.y = tf.nn.softmax(tf.matmul(self.x, self.w) + self.b)

        # 损失,即交叉熵,最常用的计算标签(label)与输出(y)之间差别的方法
        self.loss = -tf.reduce_sum(self.label * tf.log(self.y + 1e-10))

        # 反向传播,采用梯度下降的方法。调整w与b,使得损失(loss)最小
        # loss越小,那么计算出来的y值与 标签(label)值越接近,准确率越高
        # minimize 可传入参数 global_step, 每次训练 global_step的值会增加1
        # 因此,可以通过计算self.global_step这个张量的值,知道当前训练了多少步
        self.train = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(
            self.loss, global_step=self.global_step)

        # argmax 返回最大值的下标,最大值的下标即答案
        predict = tf.equal(tf.argmax(self.label, 1), tf.argmax(self.y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(predict, "float"))

train.py

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from model import Network

CKPT_DIR = 'ckpt'


class Train:
    def __init__(self):
        self.net = Network()
        # 初始化 session
        # Network() 只是构造了一张计算图,计算需要放到会话(session)中
        self.sess = tf.Session()
        # 初始化变量
        self.sess.run(tf.global_variables_initializer())
        # 读取训练和测试数据,这是tensorflow库自带的,不存在训练集会自动下载
        # 项目目录下已经下载好,删掉后,重新运行代码会自动下载
        self.data = input_data.read_data_sets('E:/mnist_tu_code/data_set', one_hot=True)

    def train(self):
        # batch_size 是指每次迭代训练,传入训练的图片张数。
        # 总的训练次数
        batch_size = 64
        train_step = 30000

        # 记录训练次数, 初始化为0
        step = 0

        # 每隔1000步保存模型
        save_interval = 100

        # tf.train.Saver是用来保存训练结果的。
        # max_to_keep 用来设置最多保存多少个模型,默认是5
        # 如果保存的模型超过这个值,最旧的模型将被删除
        saver = tf.train.Saver(max_to_keep=10)

        # 开始训练前,检查ckpt文件夹,看是否有checkpoint文件存在。
        # 如果存在,则读取checkpoint文件指向的模型,restore到sess中。
        ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(self.sess, ckpt.model_checkpoint_path)
            # 读取网络中的global_step的值,即当前已经训练的次数
            step = self.sess.run(self.net.global_step)
            print('Continue from')
            print('        -> Minibatch update : ', step)

        while step < train_step:
            # 从数据集中获取 输入和标签(也就是答案)
            x, label = self.data.train.next_batch(batch_size)
            # 每次计算train,更新整个网络
            # loss只是为了看到损失的大小,方便打印
            _, loss = self.sess.run([self.net.train, self.net.loss],
                                    feed_dict={self.net.x: x, self.net.label: label})
            step = self.sess.run(self.net.global_step)
            if step % 1000 == 0:
                print('第%5d步,当前loss:%.2f' % (step, loss))

            # 模型保存在ckpt文件夹下
            # 模型文件名最后会增加global_step的值,比如1000的模型文件名为 model-1000
            if step % save_interval == 0:
                saver.save(self.sess, CKPT_DIR + '/model', global_step=step)

    def calculate_accuracy(self):   #计算准确率
        test_x = self.data.test.images
        test_label = self.data.test.labels
        accuracy = self.sess.run(self.net.accuracy,
                                 feed_dict={self.net.x: test_x, self.net.label: test_label})
        print("准确率: %.2f,共测试了%d张图片 " % (accuracy, len(test_label)))


if __name__ == "__main__":
    app = Train()
    app.train()
    app.calculate_accuracy()

predict.py

import tensorflow as tf
import numpy as np
from PIL import Image

from model import Network

'''
使用tensorflow的模型来预测手写数字
输入是28 * 28像素的图片,输出是个具体的数字
'''

CKPT_DIR = 'ckpt'


class Predict:
    def __init__(self):
        self.net = Network()
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

        # 加载模型到sess中
        self.restore()

    def restore(self):
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            raise FileNotFoundError("未保存任何模型")

    def predict(self, image_path):
        # 读图片并转为黑白的
        img = Image.open(image_path).convert('L')
        flatten_img = np.reshape(img, 784)
        x = np.array([1 - flatten_img])
        y = self.sess.run(self.net.y, feed_dict={self.net.x: x})

        # 因为x只传入了一张图片,取y[0]即可
        # np.argmax()取得独热编码最大值的下标,即代表的数字
        print(image_path)
        print('        -> Predict digit', np.argmax(y[0]))


if __name__ == "__main__":
    app = Predict()
    app.predict('E:/mnist_tu_code/test_images/0.png')
    app.predict('E:/mnist_tu_code/test_images/1.png')
    app.predict('E:/mnist_tu_code/test_images/4.png')

你可能感兴趣的:(学习笔记)