# 实现手写体 MNIST 数据集的识别任务,代码共分为三个模块文件,分别是:前向传播模块 mnist_forward.py、反向传播(训练)模块 mnist_backward.py,以及测试模块。
import tensorflow as tf
INPUT_NODE = 784 # number of network input nodes: one per pixel of each input image
LAYER1_NODE = 500 # number of hidden-layer nodes
OUTPUT_NODE = 10 # number of output nodes: one per digit class 0-9
# The input->hidden weights w1 have shape [784, 500]; the hidden->output weights w2 have shape [500, 10].
def get_weight(shape, regularizer):
    """Create a weight variable and optionally register its L2 penalty.

    Args:
        shape: Shape of the weight tensor, e.g. ``[INPUT_NODE, LAYER1_NODE]``.
        regularizer: L2 regularization coefficient, or ``None`` to skip
            regularization entirely.

    Returns:
        The newly created ``tf.Variable``.
    """
    # Truncated-normal initialisation with a standard deviation of 0.1.
    # (Per the TensorFlow docs, samples further than two standard deviations
    # from the mean are dropped and re-drawn.)
    w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
    # Identity comparison is the correct way to test for None.
    if regularizer is not None:
        # Add this weight's L2 loss to the 'losses' collection so that it can
        # be summed into the total loss later.
        tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w
def get_bias(shape):
    """Return a new bias variable of the given shape, initialised to all zeros."""
    return tf.Variable(tf.zeros(shape))
def forward(x, regularizer):
    """Build the two-layer fully connected network and return its raw output.

    Args:
        x: Input tensor that is matrix-multiplied with a
            ``[INPUT_NODE, LAYER1_NODE]`` weight matrix.
        regularizer: L2 coefficient forwarded to ``get_weight``; ``None``
            disables regularization.

    Returns:
        The output-layer result (logits). No activation is applied here
        because the caller passes it through softmax.
    """
    # NOTE(review): the variables are unnamed, so the creation order
    # (w1, b1, w2, b2) is kept as-is -- presumably checkpoints rely on the
    # auto-generated names; verify before reordering.

    # Layer 1: x * w1 + b1, followed by ReLU.
    hidden_weights = get_weight([INPUT_NODE, LAYER1_NODE], regularizer)
    hidden_bias = get_bias([LAYER1_NODE])
    hidden_out = tf.nn.relu(tf.matmul(x, hidden_weights) + hidden_bias)

    # Layer 2: hidden_out * w2 + b2, deliberately without ReLU.
    output_weights = get_weight([LAYER1_NODE, OUTPUT_NODE], regularizer)
    output_bias = get_bias([OUTPUT_NODE])
    return tf.matmul(hidden_out, output_weights) + output_bias
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os
BATCH_SIZE = 200 # number of images fed to the network per training step
LEARNING_RATE_BASE = 0.1 # initial learning rate
LEARNING_RATE_DECAY = 0.99 # learning-rate decay rate
REGULARIZER = 0.0001 # regularization coefficient
STEPS = 50000 # number of training steps
MOVING_AVERAGE_DECAY = 0.99 # moving-average decay rate; usually set close to 1
MODEL_SAVE_PATH = './model/' # directory in which the model is saved
MODEL_NAME = 'mnist_model' # base file name of the saved model
def backward(mnist):
    """Train the network on MNIST and periodically save checkpoints.

    Builds the training graph (regularized cross-entropy loss, exponentially
    decayed learning rate, gradient descent, and moving averages of the
    trainable variables), resumes from the latest checkpoint in
    ``MODEL_SAVE_PATH`` if one exists, then runs ``STEPS`` training
    iterations.

    Args:
        mnist: Data set object; this function uses
            ``mnist.train.num_examples`` and ``mnist.train.next_batch``.
            Labels are expected to be one-hot (they are reduced with
            ``tf.argmax`` before being fed to the loss).
    """
    # Placeholders for the input images and their one-hot labels.
    x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
    y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
    # Forward pass with regularization enabled.
    y = mnist_forward.forward(x, REGULARIZER)
    # Step counter; not trainable so the optimizer never updates it directly.
    global_step = tf.Variable(0, trainable=False)

    # Softmax cross-entropy averaged over the batch, plus the regularization
    # losses collected under 'losses' during the forward pass.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cem = tf.reduce_mean(ce)
    loss = cem + tf.add_n(tf.get_collection('losses'))

    # Exponentially decayed learning rate; the decay period is the number of
    # batches in one pass over the training set.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)

    # Plain gradient descent; minimize() also increments global_step.
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

    # Maintain moving averages of all trainable variables.
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    # Running train_op triggers both the gradient step and the EMA update.
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()

    # Make sure the checkpoint directory exists before the first save;
    # Saver.save can fail when the parent directory is missing.
    if not os.path.isdir(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Resume training from the latest checkpoint, if there is one.
        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        for i in range(STEPS):
            # Fetch the next batch of images and labels and run one step.
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            # Log and checkpoint every 1000 iterations, and also on the very
            # last one so the final training steps are not lost.
            if i % 1000 == 0 or i == STEPS - 1:
                print('After', step, 'training step(s), loss on training batch is', loss_value)
                # Save the current session's variables to the checkpoint path.
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
def main():
    """Read the MNIST data set from ./data/ (one-hot labels) and train on it."""
    backward(input_data.read_data_sets('./data/', one_hot=True))


if __name__ == '__main__':
    main()
#coding:utf-8
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward
TEST_INTERVAL_SECS = 5 # seconds to sleep between two evaluation rounds
def test(mnist):
    """Repeatedly evaluate the latest checkpoint on the MNIST test set.

    Rebuilds the forward graph without regularization, restores the
    moving-average values of the parameters from the newest checkpoint in
    ``mnist_backward.MODEL_SAVE_PATH``, and prints the test accuracy. This is
    repeated every ``TEST_INTERVAL_SECS`` seconds; the function returns as
    soon as no checkpoint can be found.

    Args:
        mnist: Data set object; ``mnist.test.images`` and
            ``mnist.test.labels`` are fed to the graph.
    """
    # Build the evaluation graph in a fresh tf.Graph.
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
        y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
        # Predictions for the fed images; no regularizer at evaluation time.
        y = mnist_forward.forward(x, None)

        # A saver that maps each variable to its moving-average value, so
        # restoring loads the averaged parameters into the model.
        ema = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(ema.variables_to_restore())

        # Accuracy: fraction of samples whose arg-max prediction equals the label.
        is_correct = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
                # Stop when there is no checkpoint to evaluate.
                if not (ckpt and ckpt.model_checkpoint_path):
                    print('No checkpoint file found')
                    return
                saver.restore(sess, ckpt.model_checkpoint_path)
                # The step count is the suffix after the last '-' in the
                # checkpoint file name.
                checkpoint_name = ckpt.model_checkpoint_path.split('/')[-1]
                global_step = checkpoint_name.split('-')[-1]
                feed = {x: mnist.test.images, y_: mnist.test.labels}
                accuracy_score = sess.run(accuracy, feed_dict=feed)
                print('After', global_step, 'training step(s), test accuracy = ', accuracy_score)
            time.sleep(TEST_INTERVAL_SECS)
def main():
    """Read the MNIST data set from ./data/ (one-hot labels) and evaluate on it."""
    test(input_data.read_data_sets('./data/', one_hot=True))


if __name__ == '__main__':
    main()