该例子不关心最终的训练精度(只有90%左右),只是尽可能完整地展示一个TF工程。对于比较大的工程,单文件方式不太合适,需要多个文件配合。从逻辑上一般分为推理(Inference)、损失(Loss)、训练(Training)和评估(Evaluation)这几个步骤。代码位置:原始代码下载地址(只需要 mnist.py 和 fully_connected_feed.py)。其中 mnist.py 负责构建模型,fully_connected_feed.py 负责运行训练和评估。先看代码(下面依次是 fully_connected_feed.py 和 mnist.py),后面再做一些说明。
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trains and Evaluates the MNIST network using a feed dictionary."""
# pylint: disable=missing-docstring
#system
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import time
from six.moves import xrange # pylint: disable=redefined-builtin
#加载数据
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.examples.tutorials.mnist import mnist
# Basic hyper-parameters of the model, registered as command-line flags via
# tf.app.flags (playing the role of #define constants at the top of a C file).
flags = tf.app.flags
FLAGS = flags.FLAGS

# Optimisation settings.
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('max_steps', 2000, 'Number of steps to run trainer.')

# Network width.
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')

# Data feeding.
flags.DEFINE_integer(
    'batch_size', 100,
    'Batch size. Must divide evenly into the dataset sizes.')
# NOTE: train_dir is passed to input_data.read_data_sets() and is also where
# run_training() writes its TensorBoard summaries and checkpoints.
flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
flags.DEFINE_boolean(
    'fake_data', False,
    'If true, uses fake data for unit testing.')
def placeholder_inputs(batch_size):
    """Create the input placeholders (the network's "interface").

    Args:
      batch_size: number of examples per batch; it is baked into both
        placeholder shapes, so every feed must have exactly this many rows.

    Returns:
      images_placeholder: float32 placeholder of shape
        (batch_size, mnist.IMAGE_PIXELS).
      labels_placeholder: int32 placeholder of shape (batch_size,) holding
        class indices rather than one-hot vectors.
    """
    images_placeholder = tf.placeholder(
        tf.float32, shape=(batch_size, mnist.IMAGE_PIXELS))
    # Trailing comma matters: `(batch_size)` is just an int, not a 1-tuple.
    labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size,))
    return images_placeholder, labels_placeholder
def fill_feed_dict(data_set, images_pl, labels_pl):
    """Build the feed_dict for one step from the next batch of data_set.

    Args:
      data_set: one split returned by input_data.read_data_sets(); its
        next_batch(FLAGS.batch_size, FLAGS.fake_data) supplies the batch.
      images_pl: images placeholder, e.g. from placeholder_inputs().
      labels_pl: labels placeholder, e.g. from placeholder_inputs().

    Returns:
      A dict mapping each placeholder to its slice of the batch.
    """
    batch = data_set.next_batch(FLAGS.batch_size, FLAGS.fake_data)
    batch_images, batch_labels = batch
    return {images_pl: batch_images, labels_pl: batch_labels}
def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
    """Run one full epoch of evaluation over data_set and print the precision.

    Only whole batches are evaluated: fill_feed_dict() always draws
    FLAGS.batch_size examples, so the last
    data_set.num_examples % FLAGS.batch_size examples are not counted.

    Args:
      sess: session in which the model variables live.
      eval_correct: tensor giving the number of correct predictions in a batch.
      images_placeholder: images placeholder of the network.
      labels_placeholder: labels placeholder of the network.
      data_set: the split to evaluate (training, validation or test).

    Returns:
      The precision (correct / evaluated) as a float; 0.0 when data_set holds
      fewer than FLAGS.batch_size examples, so no batch can be evaluated.
    """
    steps_per_epoch = data_set.num_examples // FLAGS.batch_size
    num_examples = steps_per_epoch * FLAGS.batch_size
    true_count = 0
    for _ in xrange(steps_per_epoch):
        feed_dict = fill_feed_dict(data_set,
                                   images_placeholder,
                                   labels_placeholder)
        true_count += sess.run(eval_correct, feed_dict=feed_dict)
    # Guard against ZeroDivisionError when the split is smaller than a batch.
    precision = float(true_count) / num_examples if num_examples else 0.0
    print(' Num examples: %d Num correct: %d Precision @ 1: %0.04f' %
          (num_examples, true_count, precision))
    return precision
def _evaluate_splits(sess, eval_correct, images_placeholder,
                     labels_placeholder, data_sets):
    """Print the precision on the training, validation and test splits."""
    for title, data_set in (('Training Data Eval:', data_sets.train),
                            ('Validation Data Eval:', data_sets.validation),
                            ('Test Data Eval:', data_sets.test)):
        print(title)
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_set)


def run_training():
    """Train the MNIST network, logging, checkpointing and evaluating it.

    Data is read from FLAGS.train_dir via input_data.read_data_sets(); the
    same directory also receives the TensorBoard summaries and checkpoints.
    Progress is printed every 100 steps; a checkpoint is saved and all three
    splits are evaluated every 1000 steps and after the final step.
    """
    data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
    # Loop-invariant: computed once instead of on every checkpoint.
    checkpoint_file = os.path.join(FLAGS.train_dir, 'checkpoint')

    # Build everything in an explicit graph. The default graph would suffice
    # here, but this keeps the example self-contained.
    with tf.Graph().as_default():
        # Placeholders for one batch of images and (non one-hot) labels.
        images_placeholder, labels_placeholder = placeholder_inputs(
            FLAGS.batch_size)
        # Inference: the forward pass that produces the logits.
        logits = mnist.inference(images_placeholder,
                                 FLAGS.hidden1,
                                 FLAGS.hidden2)
        # Loss computed from the logits and the true labels.
        loss = mnist.loss(logits, labels_placeholder)
        # Op performing one training step with the given learning rate.
        train_op = mnist.training(loss, FLAGS.learning_rate)
        # Number of correct predictions in a batch (a sum, not a ratio).
        eval_correct = mnist.evaluation(logits, labels_placeholder)
        # Merge all summaries defined so far, for TensorBoard.
        summary_op = tf.merge_all_summaries()
        init = tf.initialize_all_variables()
        saver = tf.train.Saver()

        # Context manager: the session is closed (releasing its resources)
        # even if training raises.
        with tf.Session() as sess:
            summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                                    sess.graph)
            try:
                sess.run(init)
                for step in xrange(FLAGS.max_steps):
                    start_time = time.time()
                    feed_dict = fill_feed_dict(data_sets.train,
                                               images_placeholder,
                                               labels_placeholder)
                    # train_op has no output (hence "_"); the loss is fetched
                    # in the same run.
                    _, loss_value = sess.run([train_op, loss],
                                             feed_dict=feed_dict)
                    duration = time.time() - start_time

                    # Every 100 steps: print progress and write summaries.
                    if step % 100 == 0:
                        print('Step %d: loss = %.2f (%.3f sec)' %
                              (step, loss_value, duration))
                        summary_str = sess.run(summary_op,
                                               feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # Every 1000 steps and after the last step: save a
                    # checkpoint and evaluate on all three splits.
                    if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                        saver.save(sess, checkpoint_file, global_step=step)
                        _evaluate_splits(sess,
                                         eval_correct,
                                         images_placeholder,
                                         labels_placeholder,
                                         data_sets)
            finally:
                # Flush pending events and close the event file.
                summary_writer.close()
def main(_):
    """Script entry point; the argument passed in is ignored."""
    run_training()


if __name__ == '__main__':
    # Hand control to tf.app.run(), which dispatches to main() above.
    tf.app.run()
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
# One output class per digit, 0 through 9.
NUM_CLASSES = 10
# Side length of a (square) MNIST image, in pixels.
IMAGE_SIZE = 28
# Number of pixels per image, i.e. the length of a flattened image vector.
IMAGE_PIXELS = IMAGE_SIZE ** 2
def _fully_connected(scope_name, inputs, in_units, out_units, activation=None):
    """Build one fully connected layer inside tf.name_scope(scope_name).

    Weights are drawn from a truncated normal with stddev 1/sqrt(in_units) and
    named 'weights'; biases start at zero and are named 'biases'. When
    `activation` is given it is applied to inputs @ weights + biases.
    """
    with tf.name_scope(scope_name):
        weights = tf.Variable(
            tf.truncated_normal([in_units, out_units],
                                stddev=1.0 / math.sqrt(float(in_units))),
            name='weights')
        biases = tf.Variable(tf.zeros([out_units]), name='biases')
        pre_activation = tf.matmul(inputs, weights) + biases
        if activation is None:
            return pre_activation
        return activation(pre_activation)


def inference(images, hidden1_units, hidden2_units):
    """Build the forward pass from the input images up to the logits.

    Two ReLU hidden layers are followed by a linear layer with NUM_CLASSES
    outputs; no softmax is applied here (loss() works on the raw logits).
    This can serve as a template for fully connected networks.

    Args:
      images: 2-D float tensor whose last dimension is IMAGE_PIXELS (it is
        multiplied by an [IMAGE_PIXELS, hidden1_units] weight matrix).
      hidden1_units: width of the first hidden layer.
      hidden2_units: width of the second hidden layer.

    Returns:
      Logits tensor with NUM_CLASSES columns.
    """
    hidden1 = _fully_connected('hidden1', images,
                               IMAGE_PIXELS, hidden1_units,
                               activation=tf.nn.relu)
    hidden2 = _fully_connected('hidden2', hidden1,
                               hidden1_units, hidden2_units,
                               activation=tf.nn.relu)
    return _fully_connected('softmax_linear', hidden2,
                            hidden2_units, NUM_CLASSES)
def loss(logits, labels):
    """Mean softmax cross-entropy between the logits and the integer labels.

    Args:
      logits: output of inference(), with NUM_CLASSES columns.
      labels: class indices (not one-hot vectors), one per row of logits.

    Returns:
      Scalar tensor named 'xentropy_mean'.
    """
    labels = tf.to_int64(labels)
    # Pass logits/labels by keyword: TF >= 1.0 rejects positional arguments
    # for this op (and reversed their order), while older releases accept
    # these keyword names as well.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels, name='xentropy')
    return tf.reduce_mean(cross_entropy, name='xentropy_mean')
def training(loss, learning_rate):
    """Create the training op (plain gradient descent) and a loss summary.

    Args:
      loss: scalar loss tensor, as returned by loss().
      learning_rate: gradient-descent step size.

    Returns:
      The op that applies one gradient-descent update to minimize `loss`,
      also advancing the 'global_step' counter variable.
    """
    # Record the loss under its op name so TensorBoard can plot it.
    tf.scalar_summary(loss.op.name, loss)
    sgd = tf.train.GradientDescentOptimizer(learning_rate)
    # Non-trainable counter of completed training steps.
    step_counter = tf.Variable(0, name='global_step', trainable=False)
    return sgd.minimize(loss, global_step=step_counter)
def evaluation(logits, labels):
    """Count how many labels are the top-1 prediction of the logits.

    Args:
      logits: output of inference().
      labels: class indices (not one-hot vectors).

    Returns:
      int32 scalar tensor: the number of correct predictions (a count, not a
      ratio).
    """
    is_hit = tf.nn.in_top_k(logits, labels, 1)
    # bool -> int32 so the hits can be summed.
    hits = tf.cast(is_hit, tf.int32)
    return tf.reduce_sum(hits)
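1) tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name=None)
(注意:从 TF 1.0 起该函数只能用关键字参数调用,即 labels=..., logits=...,位置参数的顺序也变了。)
上面的函数在效果上相当于对 tf.nn.softmax_cross_entropy_with_logits 的封装(后者先对 logits 做 softmax,再与 one-hot 标签计算交叉熵)。前缀 sparse_ 表示标签直接以整数类别下标给出,而不是 one-hot 向量,相当于把"制作 one-hot 标签"这一步放进了函数内部。这个函数正是为了适应分类问题中最常见的整数标签形式而设立的。与上述函数等效的代码如下(假设 labels 是 int32 类型的一维类别下标,NUM_CLASSES 为类别数):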
# Illustrative equivalent of sparse_softmax_cross_entropy_with_logits, written
# with the TF 0.x-era API (tf.pack, tf.concat(dim, values)): build the one-hot
# labels by hand, then call the dense softmax cross-entropy op.
# Free variables: `labels`, `logits`, and NUM_CLASSES (defined in mnist.py).
# NOTE(review): labels presumably must be int32 here, because they are
# concatenated with the int32 output of tf.range below.
batch_size = tf.size(labels)
# Turn the label vector into a single column.
labels = tf.expand_dims(labels, 1)
# Row numbers 0 .. batch_size-1, also as a single column.
indices = tf.expand_dims(tf.range(0, batch_size, 1), 1)
# (row, class) pairs: the positions that should hold 1.0.
concated = tf.concat(1, [indices, labels])
# Dense [batch_size, NUM_CLASSES] matrix: 1.0 at those positions, 0.0 elsewhere.
onehot_labels = tf.sparse_to_dense(
    concated, tf.pack([batch_size, NUM_CLASSES]), 1.0, 0.0)
# NOTE: TF >= 1.0 requires keyword arguments here (labels=..., logits=...).
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits,
                                                        onehot_labels,
                                                        name='xentropy')
2) tf.scalar_summary(tags, values, collections=None, name=None)
3) tf.nn.in_top_k(predictions, targets, k, name=None)
4) class tf.train.SummaryWriter
5) tf.merge_all_summaries(key='summaries')
6) tf.train.SummaryWriter.add_summary(summary, global_step=None)
7) tf.train.SummaryWriter.flush()
8) class tf.train.Saver
9) tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True)
10) tf.train.Saver.restore(sess, save_path)