本程序使用TensorFlow实现输入手写数字识别结果,IDE为Pycharm。实现的主要功能是实现断点续训,输入真实图片,输出预测值。
有完整代码。分为四个文件:
forward.py
backward.py
test.py:测试已经训练好的神经网络,查看正确率
app.py:实现应用,输入图片,实现识别技术。
本NN采用两层的全连接网络,输入节点数为784,中间节点数为500,输出10分类。
全连接层结构:
[1,784] ->[1,500]->[1,10]
前向传播过程,其中INPUT_NODE=784, LAYER1_NODE=500,OUTPUT_NODE=10.
前向传播代码:
# -*-coding:gbk-*-
import tensorflow as tf
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500
def get_weight(shape, regularizer):
w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
if regularizer != None:
# collection容器可以保存很多值,这里使用L2正则化,在w的损失加入到losses中
tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
return w
def get_bias(shape):
print (shape)
b = tf.Variable(tf.zeros(shape))
return b
def forward(x, regularizer):
w1 = get_weight([INPUT_NODE, LAYER1_NODE], regularizer)
b1 = get_bias([LAYER1_NODE])
y1 = tf.nn.relu(tf.matmul(x, w1) + b1)
w2 = get_weight([LAYER1_NODE, OUTPUT_NODE], regularizer)
b2 = get_bias([OUTPUT_NODE])
y = tf.matmul(y1, w2) + b2 # 输出层不过激活
return y
在方向传播中,把模型保存指定路径,注意路径文件夹要先创建文件,否则可能出错。
断点训练可以是把训练好的模型保存下来,再次使用不需要从头开始训练,而是从之前断开的位置开始,使用ckpt可以实现复现的计算图。
# tf.train.get_checkpoint_state(checkpoint_dir,latest_filename=None)
# 函数表示如果断点文件夹中包含有效断点状态文件,则返回该文件。
# 参数说明:
# checkpoint_dir:表示存储断点文件的目录
# latest_filename=None:断点文件的可选名称,默认为“checkpoint”
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
# saver.restore(sess, ckpt.model_checkpoint_path)
# 该函数表示恢复当前会话,将 ckpt 中的值赋给 w 和 b。
# 参数说明:
# sess:表示当前会话,之前保存的结果将被加载入这个会话
# ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看 checkpoint 文件
#coding:utf-8
# -*-coding:gbk-*-
# 0导入模块,生成数据集
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
import forward
STEPS = 50000
BATCH_SIZE = 200
LEARNING_RATE_BASE = 0.1
LEARNING_RATE_DECAY = 0.99
GEGULARIZER = 0.0001
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "mnist_model"
def backward(mnist):
x = tf.placeholder(tf.float32, [None, forward.INPUT_NODE])
# y_->labels y->logist
y_ = tf.placeholder(tf.float32, [None, forward.OUTPUT_NODE])
y = forward.forward(x, GEGULARIZER)
global_step = tf.Variable(0, trainable=False)
# 定义loss函数
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
cem = tf.reduce_mean(ce)
loss = cem + tf.add_n(tf.get_collection('losses')) # 加上w的损失
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step, # 为样本个数
mnist.train.num_examples / BATCH_SIZE,
LEARNING_RATE_DECAY,
staircase=True)
# 定义backward 方法,包括正则化
#train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
# 在模型训练时候使用滑动平均,模型更加健壮
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
ema_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([train_step, ema_op]):
train_op = tf.no_op(name='train')
saver = tf.train.Saver() # 实例化saver对象
with tf.Session() as sess:
init_op = tf.initialize_all_variables()
#init_op = tf.global_variables_initializer()
sess.run(init_op)#执行训练过程
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess,ckpt.model_checkpoint_path)
# 训练模型
for i in range(STEPS):
xs, ys = mnist.train.next_batch(BATCH_SIZE) # 随机抽取BATCH_SIZE数据输入NN,xs:(200,784),ys:(200,10)
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={
x: xs, y_: ys})
if i % 1000 == 0:
print("After %d step(s),loss on all data is %g" % (step, loss_value))
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
# NN每隔1000轮,将参数信息保存到指定路径,并注明训练轮数
def practice(mnist):
print "train data size:",mnist.train.num_examples
print("validation data size:",mnist.validation.num_examples)
print ("test data size:",mnist.test.num_examples)
print mnist.train.labels[0]
print mnist.train.images[0]
def main():
mnist = input_data.read_data_sets("./data/", one_hot=True)
#practice(mnist)
backward(mnist)
if __name__ == '__main__':
main()
# coding:utf-8
import sys
sys.path.append('/usr/local/lib/python2.7/dist-packages')
# 0导入模块,生成数据集
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
import forward
import time
import backward
TEST_INTERVAL_SECS = 5
def test(mnist):
# 复现计算图
with tf.Graph().as_default() as g:
x = tf.placeholder(tf.float32, [None, forward.INPUT_NODE])
y_ = tf.placeholder(tf.float32, [None, forward.OUTPUT_NODE])
y = forward.forward(x, None)
# 实例化可还原滑动平均的saver
ema = tf.train.ExponentialMovingAverage(backward.MOVING_AVERAGE_DECAY)
ema_restore = ema.variables_to_restore()
saver = tf.train.Saver(ema_restore)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
while True:
with tf.Session() as sess:
# 加载训练好的模型
ckpt = tf.train.get_checkpoint_state(backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
# 恢复回话
saver.restore(sess, ckpt.model_checkpoint_path)
# 恢复轮数,使用split函数获得已经训练的轮数
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
# 计算准确率
accuracy_score = sess.run(accuracy, feed_dict={
x: mnist.test.images, y_: mnist.test.labels})
print("After %s training step(s),test accuracy = %g " % (global_step, accuracy_score))
else:
print("NO checkpoint file found")
return
time.sleep(TEST_INTERVAL_SECS)
def main():
mnist = input_data.read_data_sets("./data/", one_hot=True)
test(mnist)
if __name__ == '__main__':
main()
#coding:utf-8
import tensorflow as tf
import numpy as np
import forward
import backward
from PIL import Image
# 图片预处理
def pre_pic(testPic):
img = Image.open(testPic)
img.show()
# 改变图片规格,适应神经网络的输入规格
reIm = img.resize((28,28),Image.ANTIALIAS)
im_arr = np.array(reIm.convert('L'))
threshold = 50 # 设定阈值,进行二值化
for i in range(28):
for j in range(28):
im_arr[i][j] = 255- im_arr[i][j]
if(im_arr[i][j]<threshold):
im_arr[i][j] = 1
else: im_arr[i][j]=255
nm_arr = im_arr.reshape([1,784])
nm_arr = nm_arr.astype(np.float32)
img_arr = np.multiply(nm_arr,1.0/255)
# 变成一维列表return
return img_arr
def restore_model(testPicArr):
with tf.Graph().as_default() as tg:
x = tf.placeholder(tf.float32,[None,forward.INPUT_NODE])
y = forward.forward(x,None)
preValue = tf.arg_max(y,1)
variable_averages = tf.train.ExponentialMovingAverage(backward.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
# 恢复会话
saver.restore(sess, ckpt.model_checkpoint_path)
preValue = sess.run(preValue,feed_dict={
x:testPicArr})
return preValue
else:
print("NO checkpoint file found")
return -1
def application():
testNum = input("input the number of test pictures:")
for i in range(testNum):
testPic = raw_input("the path of test picture:")
testPicArr = pre_pic(testPic)
preValue = restore_model(testPicArr)
print "The prediction number is:",preValue
def main():
application()
if __name__ == '__main__':
main()