The example below is a small convolutional neural network written in TensorFlow (two convolutional layers followed by two fully connected layers) that classifies MNIST. The aim is to keep it simple and clear while walking through a typical TF classification workflow.
The example consists of just two files:
train.py: data loading and model training.
# coding=utf-8
import tensorflow as tf
import model
import os
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('dataset/', one_hot=True)

tf.app.flags.DEFINE_integer('image_width', 28, 'width of image')
tf.app.flags.DEFINE_integer('image_height', 28, 'height of image')
tf.app.flags.DEFINE_integer('channel', 1, 'channel of image')
tf.app.flags.DEFINE_float('keep_prob', 1.0, 'dropout keep probability')
tf.app.flags.DEFINE_float('lr', 0.001, 'learning rate')
tf.app.flags.DEFINE_integer('batch_size', 32, 'batch size')
tf.app.flags.DEFINE_integer('epochs', 100, 'num of epochs')
tf.app.flags.DEFINE_integer('num_classes', 10, 'num of classes')
tf.app.flags.DEFINE_string('checkpoints', './checkpoints/model.ckpt', 'path of checkpoints')
tf.app.flags.DEFINE_boolean('continue_training', False, 'resume from the saved checkpoint')
FLAGS = tf.app.flags.FLAGS


def main(_):
    inputs = tf.placeholder(dtype=tf.float32, shape=[None, FLAGS.image_width * FLAGS.image_height])
    labels = tf.placeholder(dtype=tf.int32, shape=[None, FLAGS.num_classes])

    # Control GPU resource utilization: grow memory usage on demand.
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Build the network; model.build returns raw (pre-softmax) logits.
    logits = model.build(inputs, FLAGS.image_height, FLAGS.image_width, FLAGS.channel, FLAGS.keep_prob, True)
    # Loss: softmax_cross_entropy_with_logits applies softmax internally, so it must be fed raw logits.
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    # Optimizer
    train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(cross_entropy)
    # Evaluation
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with sess.as_default():
        # Initialization
        saver = tf.train.Saver(max_to_keep=1000)
        sess.run(tf.global_variables_initializer())
        # Restore the weights file when resuming a previous run
        if FLAGS.continue_training:
            saver.restore(sess, FLAGS.checkpoints)
        # Begin training
        for epoch in range(FLAGS.epochs):
            for k in range(mnist.train.num_examples // FLAGS.batch_size):
                batch = mnist.train.next_batch(FLAGS.batch_size)
                _, loss, acc = sess.run([train_op, cross_entropy, accuracy],
                                        feed_dict={inputs: batch[0], labels: batch[1]})
                print('loss : %f accuracy : %f' % (loss, acc))
            print('test accuracy:', accuracy.eval({inputs: mnist.test.images, labels: mnist.test.labels}))
            # Create the checkpoint directory if needed
            if not os.path.isdir('checkpoints'):
                os.makedirs('checkpoints')
            saver.save(sess, FLAGS.checkpoints)


if __name__ == '__main__':
    tf.app.run()
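The post itself only covers training. For completeness, here is a minimal inference sketch (my own addition, not one of the original two files) that restores the saved checkpoint and classifies a few test images; it assumes the model.py shown next and the default checkpoint path from train.py:

# predict.py -- illustrative sketch only, not part of the original example
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import model

mnist = input_data.read_data_sets('dataset/', one_hot=True)

inputs = tf.placeholder(dtype=tf.float32, shape=[None, 28 * 28])
# Rebuild the same graph; train=False skips dropout at inference time.
logits = model.build(inputs, 28, 28, 1, 1.0, False)
prediction = tf.argmax(logits, 1)

with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, './checkpoints/model.ckpt')
    # Predict the classes of the first ten test images.
    print(sess.run(prediction, feed_dict={inputs: mnist.test.images[:10]}))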
model.py: network definition.
import tensorflow as tf


def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


def build(inputs, height, width, channel, keep_prob, train):
    x_image = tf.reshape(inputs, [-1, height, width, channel])
    # block_1: 5x5 conv with 32 filters, then 2x2 max pool (28x28 -> 14x14)
    weight_1 = weight_variable([5, 5, channel, 32])
    bias_1 = bias_variable([32])
    conv_1 = tf.nn.relu(conv2d(x_image, weight_1) + bias_1)
    pool_1 = max_pool(conv_1)
    # block_2: 5x5 conv with 64 filters, then 2x2 max pool (14x14 -> 7x7)
    weight_2 = weight_variable([5, 5, 32, 64])
    bias_2 = bias_variable([64])
    conv_2 = tf.nn.relu(conv2d(pool_1, weight_2) + bias_2)
    pool_2 = max_pool(conv_2)
    # fc_1: flatten the 7x7x64 feature map and project to 1024 units
    fc_weight_1 = weight_variable([7 * 7 * 64, 1024])
    fc_bias_1 = bias_variable([1024])
    flat = tf.reshape(pool_2, [-1, 7 * 7 * 64])
    fc_1 = tf.nn.relu(tf.matmul(flat, fc_weight_1) + fc_bias_1)
    # Dropout (training only)
    if train:
        fc_1 = tf.nn.dropout(fc_1, keep_prob=keep_prob)
    # fc_2: return raw logits; softmax is applied by the loss in train.py
    fc_weight_2 = weight_variable([1024, 10])
    fc_bias_2 = bias_variable([10])
    logits = tf.matmul(fc_1, fc_weight_2) + fc_bias_2
    return logits
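The 7 * 7 * 64 flatten size in fc_1 follows from the two pooling steps: a 'SAME' convolution with stride 1 keeps the 28x28 spatial size, and each 2x2 max pool halves it, so 28 becomes 14 and then 7, with 64 channels after block_2. A tiny standalone check (my addition, not from the post) confirms the arithmetic:

# Shape sanity check for the two pooling stages (illustrative only).
import tensorflow as tf

def pool(t):
    return tf.nn.max_pool(t, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

x = tf.zeros([1, 28, 28, 1])  # dummy batch with MNIST's spatial size
pool_1 = pool(x)
pool_2 = pool(pool_1)
print(pool_1.shape)  # (1, 14, 14, 1): first 2x2 max pool halves 28x28
print(pool_2.shape)  # (1, 7, 7, 1): second pool gives the 7x7 map flattened in fc_1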
Sample output:
loss : 1.461150 accuracy : 1.000000
loss : 1.461150 accuracy : 1.000000
loss : 1.461150 accuracy : 1.000000
loss : 1.461150 accuracy : 1.000000
loss : 1.461150 accuracy : 1.000000
loss : 1.461150 accuracy : 1.000000
test accuracy: 0.9924
The structure of this example is deliberately simple; if you have questions about any of the details, or anything else, feel free to leave a comment.