运用tensorflow全连接神经网络进行MNIST手写数字图像识别

本文记录tensorflow搭建简单神经网络,并进行模块化处理,目的在于总结并提取简单神经网络搭建的基本思想和方法,提炼核心结构和元素,从而能够移植到日后深入学习中去。

  • 1 模块提炼
    • 1.1 template_forward.py
      • a 结构分析
      • b 代码分析
    • 1.2 template_backward.py
      • a 结构分析
      • b 代码分析
  • 2 MNIST手写数字识别——模块化编写
    • 2.1 forward.py
    • 2.2 backward.py
    • 2.3 evaluate.py


1 模块提炼

1.1 template_forward.py

a 结构分析

  • forward.py用于构建网络图结构,具体分为以下几步:
    • forward()主方法 – 设计网络层数和维度
    • get_weight() – 传入维度和正则化信息,生成符合要求的weight
    • get_bias() – 传入维度信息,生成符合要求的bias

b 代码分析

import tensorflow as tf

# Main entry point: defines the forward-propagation network structure.
def forward(x, regularizer):
    # Template skeleton: the right-hand sides below are intentionally left
    # empty and must be filled in with the weight, bias and output expressions.
    w = 
    b = 
    y = 
    return y

# Create a weight variable.
def get_weight(shape, regularizer=None):
    """
    Take the requested shape and an optional regularizer coefficient (lambda).

    Return a tensorflow Variable that is optimized as a weight. When
    regularizer is truthy, the L2 penalty of the weight is added to the
    'losses' collection.
    """
    # Template placeholder: supply the initial value built from `shape`.
    w = tf.Variable()
    if regularizer:
        tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w

# Create a bias variable.
def get_bias(shape):
    """
    Take the requested shape.

    Return a tensorflow Variable that is optimized as a bias.
    """
    # Template placeholder: supply the initial value built from `shape`.
    b = tf.Variable( )
    return b

1.2 template_backward.py

a 结构分析

  • backward.py用于定义反向传播训练过程(损失函数、学习率、优化器及训练会话),具体分为以下几步:
    • 定义常量:
      • STEPS:总训练轮数
      • BATCH_SIZE:每batch训练样本数
      • LEARNING_RATE_BASE:学习率初值,作为指数衰减学习率的初始值
      • LEARNING_RATE_DECAY:学习率衰减基数,作为指数衰减项的基底
      • REGULARIZER:正则化强度 λ
    • backward()主方法
      • 定义输入张量、网络输出
      • 定义全局计数器、损失函数、指数衰减学习率、单步训练优化器
      • 定义滑动平均(可选)
      • 启动会话,开始训练

b 代码分析

import tensorflow as tf 

# Constants used by the training template.
# Total number of training iterations.
STEPS = 40000
# Number of training examples per batch.
BATCH_SIZE = 30
# Initial value of the exponentially decaying learning rate.
LEARNING_RATE_BASE = 0.001
# Decay rate (base) of the exponential learning-rate schedule.
LEARNING_RATE_DECAY = 0.999
# Regularization strength (lambda) passed to forward.forward.
REGULARIZER = 0.01
# Decay rate of the optional moving average; backward() reads this name,
# so it has to exist alongside the other constants.
MOVING_AVERAGE_DECAY = 0.99
# Training information is printed every CYCLE_OBSERVED iterations.
CYCLE_OBSERVED = 2000

def backward():
    """Template for the training (back-propagation) procedure.

    Builds the placeholders, the network output, the loss, a decaying
    learning rate, the optimizer step and the optional moving-average update,
    then runs the training loop inside a session. Expressions left empty
    (`tf.placeholder( )`, `feed_dict={}`, `print()`) are meant to be filled
    in by the concrete script.
    """
    # Placeholders: x holds the training features, y_ the training labels.
    x = tf.placeholder( )
    y_ = tf.placeholder( )
    # Prediction op y, built by the forward module.
    y = forward.forward(x, REGULARIZER)

    # Global step counter; drives the learning-rate decay and moving average.
    global_step = tf.Variable(0, trainable=False)
    # (Pick ONE of the two losses below; as written, the second assignment
    # to `loss` replaces the first.)
    # Mean-squared-error loss, typically used for binary classification or
    # regression.
    loss_mse = tf.reduce_mean(tf.square(y_ - y))
    loss = loss_mse + tf.add_n(tf.get_collection('losses'))
    # Cross-entropy loss, used for multi-class classification.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cem = tf.reduce_mean(ce)
    loss = cem + tf.add_n(tf.get_collection('losses'))
    # Exponentially decaying learning rate.
    # NOTE: train_set_count is not defined in the visible template; the
    # concrete script has to supply it.
    learning_rate = tf.train.exponential_decay(
            LEARNING_RATE_BASE,
            global_step,
            train_set_count / BATCH_SIZE,
            LEARNING_RATE_DECAY,
            staircase=True
        )

    train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

    # Moving average of the trainable parameters (optional); it also uses the
    # global step counter. MOVING_AVERAGE_DECAY is read as a module-level name.
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    # Control dependency: both train_step and ema_op must finish before the
    # ops created inside the block run. The no_op does nothing itself; it
    # merely bundles the two ops under the name 'train'.
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')

    # Open a session.
    with tf.Session() as sess:
        # Initialize all global variables.
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        for i in range(STEPS):
            # Custom batching goes here. Run train_op rather than train_step
            # so that the moving-average update is executed as well.
            sess.run(train_op, feed_dict={})
            # Print training information every CYCLE_OBSERVED iterations.
            if i % CYCLE_OBSERVED == 0:
                print()

# Run the training template only when this file is executed as a script.
if __name__ == '__main__':
    backward()

2 MNIST手写数字识别——模块化编写

2.1 forward.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2018/7/26 09:15
# @Author  : zhoujl
# @Site    : 
# @File    : forward.py
# @Software: PyCharm
import tensorflow as tf

# Width of one input example (first dimension of the first weight matrix).
INPUT_NODE = 784
# Number of network outputs (second dimension of the last weight matrix).
OUTPUT_NODE = 10
# Number of units in the single hidden layer.
LAYER_1_NODE = 500

def forward(x, regularizer):
    """Build the two-layer fully connected network and return its output.

    Args:
        x: input tensor that is matrix-multiplied with an
            [INPUT_NODE, LAYER_1_NODE] weight.
        regularizer: regularization coefficient forwarded to get_weight;
            a falsy value disables the regularization term.

    Returns:
        The output of the second layer. No activation is applied to it.
    """
    # Hidden layer: affine transform followed by ReLU.
    # The variables are created in the order weight, bias, weight, bias.
    hidden_w = get_weight(shape=[INPUT_NODE, LAYER_1_NODE], regularizer=regularizer)
    hidden_b = get_bias(shape=[LAYER_1_NODE])
    hidden_out = tf.nn.relu(tf.matmul(x, hidden_w) + hidden_b)

    # Output layer: affine transform only.
    out_w = get_weight(shape=[LAYER_1_NODE, OUTPUT_NODE], regularizer=regularizer)
    out_b = get_bias(shape=[OUTPUT_NODE])
    return tf.matmul(hidden_out, out_w) + out_b


def get_weight(shape, regularizer=None):
    """Create a weight Variable drawn from a truncated normal (stddev 0.1).

    Args:
        shape: shape of the weight.
        regularizer: L2 regularization coefficient. When truthy, the L2
            penalty of the new weight is added to the 'losses' collection.

    Returns:
        The newly created tf.Variable.
    """
    initial_value = tf.truncated_normal(shape=shape, stddev=0.1)
    weight = tf.Variable(initial_value)
    if not regularizer:
        return weight
    penalty = tf.contrib.layers.l2_regularizer(regularizer)(weight)
    tf.add_to_collection('losses', penalty)
    return weight


def get_bias(shape):
    """Create a bias Variable of the given shape, initialized to 0.01.

    Args:
        shape: shape of the bias.

    Returns:
        The newly created tf.Variable.
    """
    initial_value = tf.constant(0.01, shape=shape)
    return tf.Variable(initial_value)

2.2 backward.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2018/7/26 09:29
# @Author  : zhoujl
# @Site    : 
# @File    : backward.py
# @Software: PyCharm
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import forward

# Total number of training iterations.
STEPS = 30000
# The loss is printed and a checkpoint is saved every LOG_CYCLE iterations.
LOG_CYCLE = 1000
# Number of training examples per batch.
BATCH_SIZE = 200
# Initial learning rate of the exponential-decay schedule.
LEARNING_RATE_BASE = 0.001
# Decay rate (base) of the exponential learning-rate schedule.
LEARNING_RATE_DECAY = 0.999
# Regularization coefficient passed to forward.forward.
REGULARIZER = 0.0001
# Decay rate of the exponential moving average of the trainable variables.
MOVING_AVERAGE_DECAY = 0.99
# Directory and file-name prefix used for saved checkpoints.
MODEL_SAVE_PATH = './model/'
MODEL_NAME = 'mnist_model'


def backward(mnist):
    """Train the network from forward.py and save checkpoints periodically.

    Args:
        mnist: dataset object providing `train.num_examples` and
            `train.next_batch(batch_size)`; the script entry point passes the
            result of `input_data.read_data_sets('./data/', one_hot=True)`.

    Side effects:
        Prints the loss and writes a checkpoint under MODEL_SAVE_PATH every
        LOG_CYCLE iterations. The directory is created if it is missing.
    """
    x = tf.placeholder(tf.float32, shape=[None, forward.INPUT_NODE])
    y_ = tf.placeholder(tf.float32, shape=[None, forward.OUTPUT_NODE])
    y = forward.forward(x, REGULARIZER)
    global_step = tf.Variable(0, trainable=False)

    # sparse_softmax_cross_entropy_with_logits expects
    # logits.shape=(BATCH_SIZE, 10) and labels.shape=(BATCH_SIZE) with integer
    # labels, hence the argmax over the one-hot label vectors.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cem = tf.reduce_mean(ce)
    loss = cem + tf.add_n(tf.get_collection('losses'))

    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        # Number of batches needed to go through the training set once.
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True
    )

    # Passing global_step to minimize() is what makes it advance once per
    # training step.
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    # Both train_step and ema_op must finish before ops created in this block
    # run. The no_op does nothing itself; it only groups the two under 'train'.
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')

    # Make sure the checkpoint directory exists before the first save;
    # saving into a missing parent directory can fail.
    if not os.path.isdir(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)
    checkpoint_prefix = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        for i in range(STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_val = sess.run([train_op, loss], feed_dict={x: xs, y_: ys})
            if i % LOG_CYCLE == 0:
                print('Iter {}, loss is {}'.format(i, loss_val))
                saver.save(sess, checkpoint_prefix, global_step=global_step)


# Script entry point: load the dataset from ./data/ with one-hot labels,
# then start training.
if __name__ == '__main__':
    mnist = input_data.read_data_sets('./data/', one_hot=True)
    backward(mnist)

2.3 evaluate.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2018/7/26 10:31
# @Author  : zhoujl
# @Site    : 
# @File    : evaluate.py
# @Software: PyCharm
import os
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import forward
import backward


def evaluate(mnist, interval_secs=5):
    """Repeatedly evaluate the latest checkpoint on the test set.

    Every `interval_secs` seconds the newest checkpoint under
    backward.MODEL_SAVE_PATH is restored (using the moving-average shadow
    values) and its test accuracy is printed. The loop ends once the last
    checkpoint that the training script writes has been evaluated.

    Args:
        mnist: dataset object providing `test.images` and `test.labels`.
        interval_secs: seconds to sleep between two evaluations.
    """
    # backward.py saves when i % LOG_CYCLE == 0 for i in range(STEPS), after
    # the step counter has been incremented, so its last checkpoint carries
    # this step number.
    final_step = (backward.STEPS - 1) // backward.LOG_CYCLE * backward.LOG_CYCLE + 1

    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, shape=[None, forward.INPUT_NODE])
        y_ = tf.placeholder(tf.float32, shape=[None, forward.OUTPUT_NODE])
        # No regularization is needed when measuring accuracy.
        y = forward.forward(x, None)

        # Restore the moving-average shadow values stored in the checkpoint.
        ema = tf.train.ExponentialMovingAverage(backward.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(ema.variables_to_restore())

        # Accuracy: fraction of examples whose predicted class equals the label.
        correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(backward.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # The step number is the suffix after the last '-' in the
                    # checkpoint file name.
                    checkpoint_name = os.path.basename(ckpt.model_checkpoint_path)
                    global_step = checkpoint_name.split('-')[-1]
                    accuracy_score = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
                    print('Iter {}, test accuracy is {}'.format(global_step, accuracy_score))
                    # Training has reached its last checkpoint: leave the loop.
                    if global_step.isdigit() and int(global_step) >= final_step:
                        break
                else:
                    print('No checkpoint file found!')
            time.sleep(interval_secs)


def main():
    """Load the dataset from ./data/ with one-hot labels and evaluate on it."""
    evaluate(input_data.read_data_sets('./data/', one_hot=True))


# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()

你可能感兴趣的:(tensorflow)