01 Defining the forward propagation of the neural network
import tensorflow as tf

# Network dimensions: 28x28 input images, 10 output classes, one hidden layer.
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500

def get_weight_variable(shape, regularizer):
    weights = tf.get_variable(
        "weights", shape,
        initializer=tf.truncated_normal_initializer(stddev=0.1))
    # When a regularizer is given, add the regularization loss for these
    # weights to the 'losses' collection so the training program can sum it.
    if regularizer is not None:
        tf.add_to_collection('losses', regularizer(weights))
    return weights

def inference(input_tensor, regularizer):
    # Hidden layer: ReLU(x * W1 + b1).
    with tf.variable_scope('layer1'):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable(
            "biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
    # Output layer: raw logits (softmax is applied inside the loss function).
    with tf.variable_scope('layer2'):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable(
            "biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases
    return layer2
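
Because every parameter is created through tf.get_variable inside a named variable scope, a second call to inference in the same graph must explicitly reuse those variables rather than create duplicates; this naming is also what lets the test program below rebuild the graph and restore the weights by name. A minimal sketch of the reuse mechanics (illustrative only, assuming mnist_inference is importable):

import tensorflow as tf
import mnist_inference

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE])
    # First call creates layer1/weights, layer1/biases, layer2/weights, layer2/biases.
    y1 = mnist_inference.inference(x, None)
    # Reuse the existing variables; child scopes inherit reuse=True.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        y2 = mnist_inference.inference(x, None)
    # y1 and y2 are now computed from the same underlying parameters.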
02 Training program
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference

# Training hyperparameters.
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 30000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "MNIST_model/"
MODEL_NAME = "mnist_model"

def train(mnist):
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')

    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    y = mnist_inference.inference(x, regularizer)
    global_step = tf.Variable(0, trainable=False)

    # Maintain a shadow (moving-average) copy of every trainable variable.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # Cross-entropy on the logits; argmax converts one-hot labels to class indices.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # Total loss = cross-entropy + the L2 terms collected in 'losses'.
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

    # Decay the learning rate once per epoch (staircase).
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Group the gradient update and the moving-average update into one op.
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run(
                [train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)

def main(argv=None):
    # Create the checkpoint directory up front; Saver.save fails if it is missing.
    if not os.path.exists(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)
    mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
    train(mnist)

if __name__ == '__main__':
    main()
'''
After 1 training step(s), loss on training batch is 2.79889.
After 1001 training step(s), loss on training batch is 0.224816.
After 2001 training step(s), loss on training batch is 0.154148.
After 3001 training step(s), loss on training batch is 0.138829.
After 4001 training step(s), loss on training batch is 0.116191.
After 5001 training step(s), loss on training batch is 0.114107.
After 6001 training step(s), loss on training batch is 0.0982979.
After 7001 training step(s), loss on training batch is 0.0957554.
After 8001 training step(s), loss on training batch is 0.0807435.
After 9001 training step(s), loss on training batch is 0.073724.
After 10001 training step(s), loss on training batch is 0.0664153.
After 11001 training step(s), loss on training batch is 0.0645105.
After 12001 training step(s), loss on training batch is 0.0572586.
After 13001 training step(s), loss on training batch is 0.0571209.
After 14001 training step(s), loss on training batch is 0.0620913.
After 15001 training step(s), loss on training batch is 0.0530401.
After 16001 training step(s), loss on training batch is 0.0473576.
After 17001 training step(s), loss on training batch is 0.046784.
After 18001 training step(s), loss on training batch is 0.0434549.
After 19001 training step(s), loss on training batch is 0.0432269.
After 20001 training step(s), loss on training batch is 0.0423441.
After 21001 training step(s), loss on training batch is 0.0418092.
After 22001 training step(s), loss on training batch is 0.0371407.
After 23001 training step(s), loss on training batch is 0.0368375.
After 24001 training step(s), loss on training batch is 0.0379352.
After 25001 training step(s), loss on training batch is 0.0358277.
After 26001 training step(s), loss on training batch is 0.034691.
After 27001 training step(s), loss on training batch is 0.0370471.
After 28001 training step(s), loss on training batch is 0.0342903.
After 29001 training step(s), loss on training batch is 0.0363797.
'''
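
For reference, with staircase=True the schedule above multiplies the base learning rate by LEARNING_RATE_DECAY once per full pass over the training data. A pure-Python illustration (assuming the standard 55000-example MNIST training split):

BATCH_SIZE = 100
decay_steps = 55000 // BATCH_SIZE   # one decay per epoch = 550 batches
for step in (0, 550, 5500, 27500):
    lr = 0.8 * 0.99 ** (step // decay_steps)
    print("step %5d -> learning rate %.6f" % (step, lr))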
03 Test program
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import mnist_train

# Re-evaluate the latest checkpoint every 10 seconds.
EVAL_INTERVAL_SECS = 10

def evaluate(mnist):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}

        # Rebuild the forward pass without regularization; it is not needed at eval time.
        y = mnist_inference.inference(x, None)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Restore the moving-average shadow values in place of the raw weights.
        variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # The checkpoint filename ends in "-<global_step>".
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
                    print("After %s training step(s), validation accuracy = %g" % (global_step, accuracy_score))
                else:
                    print('No checkpoint file found')
                    return
            time.sleep(EVAL_INTERVAL_SECS)

def main(argv=None):
    mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
    evaluate(mnist)

if __name__ == '__main__':
    main()
'''
After 29001 training step(s), validation accuracy = 0.9844
...
'''
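
If a single evaluation against the most recent checkpoint is enough, the polling loop can be replaced with tf.train.latest_checkpoint. A minimal sketch under the same module layout (evaluate_once is a hypothetical helper name; here it scores the test split rather than the validation split):

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import mnist_train

def evaluate_once(mnist):
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
        y = mnist_inference.inference(x, None)
        correct = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
        # Restore moving-average values, exactly as in the polling version.
        ema = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(ema.variables_to_restore())
        with tf.Session() as sess:
            ckpt_path = tf.train.latest_checkpoint(mnist_train.MODEL_SAVE_PATH)
            if ckpt_path is None:
                print('No checkpoint file found')
                return
            saver.restore(sess, ckpt_path)
            feed = {x: mnist.test.images, y_: mnist.test.labels}
            print("test accuracy = %g" % sess.run(accuracy, feed_dict=feed))

if __name__ == '__main__':
    evaluate_once(input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True))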