MNIST Handwritten Digit Recognition with TensorFlow (Full Code and Walkthrough, Beginner-Friendly)

Recognizing handwritten digits in the MNIST dataset with a fully-connected network.

The code lives in the TensorFlow GitHub repository at https://github.com/tensorflow/tensorflow; the two files discussed here sit under the directory tensorflow/examples/tutorials/mnist/.

Understanding Tensors

A tensor generalizes vectors (rank 1) and matrices (rank 2) to arbitrary rank:

| Name | Python representation |
| ---- | --------------------- |
| Vector | v = [1, 2, 3] |
| Matrix | m = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] |
| Tensor (rank 3) | t = [[[2], [4], [6]], [[8], [10], [12]], [[14], [16], [18]]] |
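A quick way to check ranks and shapes yourself, as a minimal sketch (TensorFlow 1.x, the version this tutorial targets):

import tensorflow as tf

v = tf.constant([1, 2, 3])                              # rank 1, shape (3,)
m = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]])      # rank 2, shape (3, 3)
t = tf.constant([[[2], [4], [6]], [[8], [10], [12]],
                 [[14], [16], [18]]])                   # rank 3, shape (3, 3, 1)
print(v.get_shape(), m.get_shape(), t.get_shape())      # (3,) (3, 3) (3, 3, 1)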

fully_connected_feed.py Walkthrough

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Trains and Evaluates the MNIST network using a feed dictionary."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=missing-docstring
import argparse
import os
import sys
import time

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.examples.tutorials.mnist import mnist

# Basic model parameters as external flags.
FLAGS = None


def placeholder_inputs(batch_size):
  """Generate placeholder variables to represent the input tensors.

  These placeholders are used as inputs by the rest of the model building
  code and will be fed from the downloaded data in the .run() loop, below.

  Args:
    batch_size: The batch size will be baked into both placeholders.

  Returns:
    images_placeholder: Images placeholder.
    labels_placeholder: Labels placeholder.
  """
  # Note that the shapes of the placeholders match the shapes of the full
  # image and label tensors, except the first dimension is now batch_size
  # rather than the full size of the train or test data sets.
  images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
                                                         mnist.IMAGE_PIXELS))
  labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
  return images_placeholder, labels_placeholder


def fill_feed_dict(data_set, images_pl, labels_pl):
  """Fills the feed_dict for training the given step.

  A feed_dict takes the form of:
  feed_dict = {
      <placeholder>: <tensor of values to be passed for placeholder>,
      ....
  }

  Args:
    data_set: The set of images and labels, from input_data.read_data_sets()
    images_pl: The images placeholder, from placeholder_inputs().
    labels_pl: The labels placeholder, from placeholder_inputs().

  Returns:
    feed_dict: The feed dictionary mapping from placeholders to values.
  """
  # Create the feed_dict for the placeholders filled with the next
  # `batch size` examples.
  images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
                                                 FLAGS.fake_data)
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict


def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
  """Runs one evaluation against the full epoch of data.

  Args:
    sess: The session in which the model has been trained.
    eval_correct: The Tensor that returns the number of correct predictions.
    images_placeholder: The images placeholder.
    labels_placeholder: The labels placeholder.
    data_set: The set of images and labels to evaluate, from
      input_data.read_data_sets().
  """
  # And run one epoch of eval.
  true_count = 0  # Counts the number of correct predictions.
  steps_per_epoch = data_set.num_examples // FLAGS.batch_size
  num_examples = steps_per_epoch * FLAGS.batch_size
  for step in xrange(steps_per_epoch):
    feed_dict = fill_feed_dict(data_set,
                               images_placeholder,
                               labels_placeholder)
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
  precision = float(true_count) / num_examples
  print('Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
        (num_examples, true_count, precision))


def run_training():
  """Train MNIST for a number of steps."""
  # Get the sets of images and labels for training, validation, and
  # test on MNIST.
  data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Generate placeholders for the images and labels.
    images_placeholder, labels_placeholder = placeholder_inputs(
        FLAGS.batch_size)

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images_placeholder,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # Add to the Graph the Ops for loss calculation.
    loss = mnist.loss(logits, labels_placeholder)

    # Add to the Graph the Ops that calculate and apply gradients.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # Add the Op to compare the logits to the labels during evaluation.
    eval_correct = mnist.evaluation(logits, labels_placeholder)

    # Build the summary Tensor based on the TF collection of Summaries.
    summary = tf.summary.merge_all()

    # Add the variable initializer Op.
    init = tf.global_variables_initializer()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    # Create a session for running Ops on the Graph.
    sess = tf.Session()

    # Instantiate a SummaryWriter to output summaries and the Graph.
    summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

    # And then after everything is built:

    # Run the Op to initialize the variables.
    sess.run(init)

    # Start the training loop.
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()

      # Fill a feed dictionary with the actual set of images and labels
      # for this particular training step.
      feed_dict = fill_feed_dict(data_sets.train,
                                 images_placeholder,
                                 labels_placeholder)

      # Run one step of the model.  The return values are the activations
      # from the `train_op` (which is discarded) and the `loss` Op.  To
      # inspect the values of your Ops or variables, you may include them
      # in the list passed to sess.run() and the value tensors will be
      # returned in the tuple from the call.
      _, loss_value = sess.run([train_op, loss],
                               feed_dict=feed_dict)

      duration = time.time() - start_time

      # Write the summaries and print an overview fairly often.
      if step % 100 == 0:
        # Print status to stdout.
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        # Update the events file.
        summary_str = sess.run(summary, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)
        summary_writer.flush()

      # Save a checkpoint and evaluate the model periodically.
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
        saver.save(sess, checkpoint_file, global_step=step)
        # Evaluate against the training set.
        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.train)
        # Evaluate against the validation set.
        print('Validation Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.validation)
        # Evaluate against the test set.
        print('Test Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.test)


def main(_):
  if tf.gfile.Exists(FLAGS.log_dir):
    tf.gfile.DeleteRecursively(FLAGS.log_dir)
  tf.gfile.MakeDirs(FLAGS.log_dir)
  run_training()


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.01,
      help='Initial learning rate.'
  )
  parser.add_argument(
      '--max_steps',
      type=int,
      default=2000,
      help='Number of steps to run trainer.'
  )
  parser.add_argument(
      '--hidden1',
      type=int,
      default=128,
      help='Number of units in hidden layer 1.'
  )
  parser.add_argument(
      '--hidden2',
      type=int,
      default=32,
      help='Number of units in hidden layer 2.'
  )
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='Batch size.  Must divide evenly into the dataset sizes.'
  )
  parser.add_argument(
      '--input_data_dir',
      type=str,
      default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                           'tensorflow/mnist/input_data'),
      help='Directory to put the input data.'
  )
  parser.add_argument(
      '--log_dir',
      type=str,
      default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                           'tensorflow/mnist/logs/fully_connected_feed'),
      help='Directory to put the log data.'
  )
  parser.add_argument(
      '--fake_data',
      default=False,
      help='If true, uses fake data for unit testing.',
      action='store_true'
  )

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

Selected lines of the code are explained below.

120 data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)

Line 120 prepares the training, validation, and test datasets. TensorFlow provides a built-in module that downloads and reads the MNIST dataset directly.
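For reference, a minimal sketch of what the returned object looks like (TF 1.x tutorial API with default arguments; the download directory is just an example):

from tensorflow.examples.tutorials.mnist import input_data

data_sets = input_data.read_data_sets('/tmp/tensorflow/mnist/input_data')
print(data_sets.train.num_examples)       # 55000
print(data_sets.validation.num_examples)  # 5000
print(data_sets.test.num_examples)        # 10000
images, labels = data_sets.train.next_batch(100)  # images: (100, 784), labels: (100,)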

 123 with tf.Graph().as_default(): 

Line 123 uses the default graph. TensorFlow represents computations as graphs; a node in a graph is called an Op (operation). An Op takes zero or more tensors, performs a computation, and produces zero or more tensors.
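A minimal illustration of Ops being added to a graph (a sketch, not part of the tutorial code):

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    a = tf.constant(2)    # an Op with no inputs that produces one tensor
    b = tf.constant(3)
    c = tf.add(a, b)      # an Op that takes two tensors and produces one
print(c.graph is g)       # True: all three Ops live in graph g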

225 if __name__ == '__main__':
226   parser = argparse.ArgumentParser()
.......
276   FLAGS, unparsed = parser.parse_known_args()
277   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

Lines 225 ~ 277 parse the command-line arguments and launch TensorFlow.

218 def main(_):
219   if tf.gfile.Exists(FLAGS.log_dir):
220     tf.gfile.DeleteRecursively(FLAGS.log_dir)
221   tf.gfile.MakeDirs(FLAGS.log_dir)
222   run_training()

Lines 218 ~ 222: once TensorFlow starts, main is called first. It checks whether the log directory exists, deletes it if it does, recreates it, and finally starts training on the MNIST data.

125 images_placeholder, labels_placeholder = placeholder_inputs(
126     FLAGS.batch_size)

Lines 125 ~ 126 create placeholders for the images and their corresponding labels. Real data is fed in later when they are actually used; at this point only the shape and dtype of the data are declared.
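A self-contained sketch of the placeholder/feed_dict pattern (the names x and total are just for illustration):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(2, 3))   # shape and dtype declared up front
total = tf.reduce_sum(x)
with tf.Session() as sess:
    print(sess.run(total, feed_dict={x: [[1, 2, 3], [4, 5, 6]]}))  # 21.0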

129 logits = mnist.inference(images_placeholder,
130                          FLAGS.hidden1,
131                          FLAGS.hidden2)
132
133 # Add to the Graph the Ops for loss calculation.
134 loss = mnist.loss(logits, labels_placeholder)
135
136 # Add to the Graph the Ops that calculate and apply gradients.
137 train_op = mnist.training(loss, FLAGS.learning_rate)
138
139 # Add the Op to compare the logits to the labels during evaluation.
140 eval_correct = mnist.evaluation(logits, labels_placeholder)

Lines 129 ~ 140 create the inference (network) Op, the loss Op, the training (gradients) Op, and the evaluation Op.

142 # Build the summary Tensor based on the TF collection of Summaries.
143 summary = tf.summary.merge_all()

Line 143 merges all summary Ops into a single Op.
Every tf.summary call in TensorFlow code creates a summary Op, used to record whatever data you want to track during training, for example:

tf.summary.histogram('histogram', var)
tf.summary.scalar('loss', loss)

If you record many quantities you end up with many summary Ops, and tf.summary.merge_all merges them all into one, which is much more convenient. During training, a summary FileWriter writes this data to disk. Once training finishes you can launch TensorBoard:

tensorboard --logdir=path/to/logs

Then open the web UI at http://localhost:6006 in a browser to inspect how the various metrics evolved during training. The logdir here is the same path that was passed to the summary FileWriter.
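Putting the pieces together, a self-contained sketch of the summary pipeline (the log directory /tmp/summary_demo is just an example):

import tensorflow as tf

loss = tf.placeholder(tf.float32)
tf.summary.scalar('loss', loss)                  # create a summary Op
summary_op = tf.summary.merge_all()              # merge all summary Ops into one

with tf.Session() as sess:
    writer = tf.summary.FileWriter('/tmp/summary_demo', sess.graph)
    for step in range(10):
        summary_str = sess.run(summary_op, feed_dict={loss: 1.0 / (step + 1)})
        writer.add_summary(summary_str, step)    # write the serialized data
    writer.close()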

146 init = tf.global_variables_initializer()

Line 146 creates the variable-initialization Op.

148 # Create a saver for writing training checkpoints. 
149 saver = tf.train.Saver() 

Line 149 creates a saver for writing training checkpoints.

152 sess = tf.Session()

Line 152 creates a session context; the graph must be run inside a session.

 154 # Instantiate a SummaryWriter to output summaries and the Graph.
 155 summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

Line 155 creates a summary FileWriter, which writes the data returned by the summary Op to disk.

159 # Run the Op to initialize the variables. 
160 sess.run(init) 

Line 160 runs the initialization of all variables. The Ops created so far only describe how data flows and how it is computed; nothing actually executes until an Op is passed to sess.run(Op).
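To see the deferred execution for yourself, a minimal sketch:

import tensorflow as tf

c = tf.add(tf.constant(2), tf.constant(3))  # only describes the computation
with tf.Session() as sess:
    print(sess.run(c))                      # 5 -- the addition runs here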

162 # Start the training loop. 
163 for step in xrange(FLAGS.max_steps): 

Line 163 starts the training loop, which runs for FLAGS.max_steps steps in total.

164 start_time = time.time()  

Line 164 records the start time of the step.

166   # Fill a feed dictionary with the actual set of images and labels
167   # for this particular training step.
168   feed_dict = fill_feed_dict(data_sets.train,
169                              images_placeholder,
170                              labels_placeholder)

Lines 168 ~ 170 fetch one batch of training data and fill the image and label placeholders with real data.

177   _, loss_value = sess.run([train_op, loss],
178                            feed_dict=feed_dict)

Lines 177 ~ 178 feed the batch into the model for one training step, returning the values of train_op (which is discarded) and the loss Op. To inspect other Ops or variables, add them to the list passed to sess.run(); their values come back in the returned tuple.

180 duration = time.time() - start_time

Line 180 computes how long the step took.

183   if step % 100 == 0:
184     # Print status to stdout.
185     print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
186     # Update the events file.
187     summary_str = sess.run(summary, feed_dict=feed_dict)
188     summary_writer.add_summary(summary_str, step)
189     summary_writer.flush()

Lines 183 ~ 189 write the summary information to disk every 100 steps.

191   # Save a checkpoint and evaluate the model periodically.
192      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
193        checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
194        saver.save(sess, checkpoint_file, global_step=step)
195        # Evaluate against the training set.
196        print('Training Data Eval:')
197        do_eval(sess,
198                eval_correct,
199                images_placeholder,
200                labels_placeholder,
201                data_sets.train)
202        # Evaluate against the validation set.
203        print('Validation Data Eval:')
204        do_eval(sess,
205                eval_correct,
206                images_placeholder,
207                labels_placeholder,
208                data_sets.validation)
209        # Evaluate against the test set.
210        print('Test Data Eval:')
211        do_eval(sess,
212                eval_correct,
213                images_placeholder,
214                labels_placeholder,
215                data_sets.test)

Lines 192 ~ 215 save a checkpoint every 1000 steps and at the final step, then report the current model's precision on the training, validation, and test sets.
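Loading a saved checkpoint back is the mirror image. A hedged sketch, assuming the same graph has already been rebuilt (for example by repeating the construction steps in run_training()):

saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(FLAGS.log_dir)  # e.g. .../model.ckpt-1999
    saver.restore(sess, ckpt)
    # the session now holds the trained weights; do_eval() can be run as above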

mnist.py Walkthrough

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Builds the MNIST network.

Implements the inference/loss/training pattern for model building.

1. inference() - Builds the model as far as required for running the network
forward to make predictions.
2. loss() - Adds to the inference model the layers required to generate loss.
3. training() - Adds to the loss model the Ops required to generate and
apply gradients.

This file is used by the various "fully_connected_*.py" files and not meant to
be run.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import tensorflow as tf

# The MNIST dataset has 10 classes, representing the digits 0 through 9.
NUM_CLASSES = 10

# The MNIST images are always 28x28 pixels.
IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE


def inference(images, hidden1_units, hidden2_units):
  """Build the MNIST model up to where it may be used for inference.

  Args:
    images: Images placeholder, from inputs().
    hidden1_units: Size of the first hidden layer.
    hidden2_units: Size of the second hidden layer.

  Returns:
    softmax_linear: Output tensor with the computed logits.
  """
  # Hidden 1
  with tf.name_scope('hidden1'):
    weights = tf.Variable(
        tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
                            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
        name='weights')
    biases = tf.Variable(tf.zeros([hidden1_units]),
                         name='biases')
    hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
  # Hidden 2
  with tf.name_scope('hidden2'):
    weights = tf.Variable(
        tf.truncated_normal([hidden1_units, hidden2_units],
                            stddev=1.0 / math.sqrt(float(hidden1_units))),
        name='weights')
    biases = tf.Variable(tf.zeros([hidden2_units]),
                         name='biases')
    hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
  # Linear
  with tf.name_scope('softmax_linear'):
    weights = tf.Variable(
        tf.truncated_normal([hidden2_units, NUM_CLASSES],
                            stddev=1.0 / math.sqrt(float(hidden2_units))),
        name='weights')
    biases = tf.Variable(tf.zeros([NUM_CLASSES]),
                         name='biases')
    logits = tf.matmul(hidden2, weights) + biases
  return logits


def loss(logits, labels):
  """Calculates the loss from the logits and the labels.

  Args:
    logits: Logits tensor, float - [batch_size, NUM_CLASSES].
    labels: Labels tensor, int32 - [batch_size].

  Returns:
    loss: Loss tensor of type float.
  """
  labels = tf.to_int64(labels)
  return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)


def training(loss, learning_rate):
  """Sets up the training Ops.

  Creates a summarizer to track the loss over time in TensorBoard.

  Creates an optimizer and applies the gradients to all trainable variables.

  The Op returned by this function is what must be passed to the
  `sess.run()` call to cause the model to train.

  Args:
    loss: Loss tensor, from loss().
    learning_rate: The learning rate to use for gradient descent.

  Returns:
    train_op: The Op for training.
  """
  # Add a scalar summary for the snapshot loss.
  tf.summary.scalar('loss', loss)
  # Create the gradient descent optimizer with the given learning rate.
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  # Create a variable to track the global step.
  global_step = tf.Variable(0, name='global_step', trainable=False)
  # Use the optimizer to apply the gradients that minimize the loss
  # (and also increment the global step counter) as a single training step.
  train_op = optimizer.minimize(loss, global_step=global_step)
  return train_op


def evaluation(logits, labels):
  """Evaluate the quality of the logits at predicting the label.

  Args:
    logits: Logits tensor, float - [batch_size, NUM_CLASSES].
    labels: Labels tensor, int32 - [batch_size], with values in the
      range [0, NUM_CLASSES).

  Returns:
    A scalar int32 tensor with the number of examples (out of batch_size)
    that were predicted correctly.
  """
  # For a classifier model, we can use the in_top_k Op.
  # It returns a bool tensor with shape [batch_size] that is true for
  # the examples where the label is in the top k (here k=1)
  # of all logits for that example.
  correct = tf.nn.in_top_k(logits, labels, 1)
  # Return the number of true entries.
  return tf.reduce_sum(tf.cast(correct, tf.int32))

Selected lines of the code are explained below.

38 NUM_CLASSES = 10
39 
40 # The MNIST images are always 28x28 pixels.
41 IMAGE_SIZE = 28
42 IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE

Line 38: the MNIST dataset has 10 classes, the handwritten digits 0 through 9.
Lines 41 ~ 42: every MNIST image is 28 × 28 pixels, i.e. 784 values once flattened.


45 def inference(images, hidden1_units, hidden2_units):
46  """Build the MNIST model up to where it may be used for inference.
47
48  Args:
49    images: Images placeholder, from inputs().
50    hidden1_units: Size of the first hidden layer.
51    hidden2_units: Size of the second hidden layer.
52
53  Returns:
54    softmax_linear: Output tensor with the computed logits.
55   """
56  # Hidden 1
57  with tf.name_scope('hidden1'):
58    weights = tf.Variable(
59        tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
60                            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
61        name='weights')
62    biases = tf.Variable(tf.zeros([hidden1_units]),
63                         name='biases')
64    hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
65  # Hidden 2
66  with tf.name_scope('hidden2'):
67    weights = tf.Variable(
68        tf.truncated_normal([hidden1_units, hidden2_units],
69                            stddev=1.0 / math.sqrt(float(hidden1_units))),
70        name='weights')
71    biases = tf.Variable(tf.zeros([hidden2_units]),
72                         name='biases')
73    hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
74  # Linear
75  with tf.name_scope('softmax_linear'):
76    weights = tf.Variable(
77        tf.truncated_normal([hidden2_units, NUM_CLASSES],
78                          stddev=1.0/math.sqrt(float(hidden2_units))),
79        name='weights')
80    biases = tf.Variable(tf.zeros([NUM_CLASSES]),
81                         name='biases')
82    logits = tf.matmul(hidden2, weights) + biases
83  return logits

Lines 45 ~ 83 build the network.

56  # Hidden 1
57  with tf.name_scope('hidden1'):
58    weights = tf.Variable(
59        tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
60                            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
61        name='weights')
62    biases = tf.Variable(tf.zeros([hidden1_units]),
63                         name='biases')
64    hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)

Lines 57 ~ 64: the first hidden layer.
Line 57 opens a name scope; identically named variables in different name scopes do not collide.
Lines 58 ~ 61 create the weight variable, initialized from a truncated normal distribution.
Lines 62 ~ 63 create the bias variable, initialized to all zeros. Line 64 passes (a·w + b) through the relu activation function: w is the weight matrix, b the bias, and a is the input image in the first layer, or the previous layer's output in every later layer.
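Purely to make the shapes concrete, a numpy stand-in for the line-64 computation (batch_size=100 and hidden1_units=128 as in the defaults):

import numpy as np

a = np.zeros((100, 784))          # input images: one 784-pixel row per image
w = np.zeros((784, 128))          # weights for hidden layer 1
b = np.zeros((128,))              # biases for hidden layer 1
h = np.maximum(a.dot(w) + b, 0)   # relu(a*w + b) -> shape (100, 128)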

65  # Hidden 2
66  with tf.name_scope('hidden2'):
67    weights = tf.Variable(
68        tf.truncated_normal([hidden1_units, hidden2_units],
69                            stddev=1.0 / math.sqrt(float(hidden1_units))),
70        name='weights')
71    biases = tf.Variable(tf.zeros([hidden2_units]),
72                         name='biases')
73    hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)

Lines 66 ~ 73: the second hidden layer.

74  # Linear
75  with tf.name_scope('softmax_linear'):
76    weights = tf.Variable(
77        tf.truncated_normal([hidden2_units, NUM_CLASSES],
78                          stddev=1.0/math.sqrt(float(hidden2_units))),
79        name='weights')
80    biases = tf.Variable(tf.zeros([NUM_CLASSES]),
81                         name='biases')
82    logits = tf.matmul(hidden2, weights) + biases
83  return logits

Lines 75 ~ 82: the linear output layer.

def loss(logits, labels):
  """Calculates the loss from the logits and the labels.

  Args:
    logits: Logits tensor, float - [batch_size, NUM_CLASSES].
    labels: Labels tensor, int32 - [batch_size].

  Returns:
    loss: Loss tensor of type float.
  """
  labels = tf.to_int64(labels)
  return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

Lines 86 ~ 99 create the loss Op.

labels = tf.to_int64(labels)

Line 96 converts the labels to int64.

return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
# The line above is what the downloaded code contains.
# The version printed in the book reads:
97 cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
98     labels=labels, logits=logits, name='xentropy')
99 return tf.reduce_mean(cross_entropy, name='xentropy_mean')

Lines 97 ~ 98 use sparse_softmax_cross_entropy_with_logits, a loss for the case where each image is labeled with exactly one class.
Line 99 takes the mean over the batch.
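To make the loss concrete, a numpy sketch of sparse softmax cross entropy for a single example (illustrative only, not the TF implementation):

import numpy as np

logits = np.array([2.0, 1.0, 0.1])                # unnormalized scores for 3 classes
label = 0                                         # index of the correct class
probs = np.exp(logits) / np.sum(np.exp(logits))   # softmax: [0.659, 0.242, 0.099]
loss = -np.log(probs[label])                      # cross entropy, about 0.417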

tf.summary.scalar('loss', loss)

Line 120 records the loss value as a scalar summary.

optimizer = tf.train.GradientDescentOptimizer(learning_rate)

Line 122 creates a gradient-descent optimizer with the given learning rate.

global_step = tf.Variable(0, name='global_step', trainable=False)

Line 124 creates a variable to track the global step count.

train_op = optimizer.minimize(loss, global_step=global_step)

Line 127 asks the optimizer to minimize the loss; this one Op computes the gradients, applies them, and increments global_step.
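minimize() is shorthand for two steps that can also be written explicitly (an equivalent sketch using the same TF 1.x optimizer API):

grads_and_vars = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
# apply_gradients updates the variables and increments global_step by one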

Finally, run fully_connected_feed.py and you will see the classification results. If you like, you can also resize your own handwritten digits to the expected input size and have the program recognize them, as sketched below.
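If you want to try your own handwriting, a rough preprocessing sketch (assumes Pillow is installed; my_digit.png is a hypothetical file):

import numpy as np
from PIL import Image

img = Image.open('my_digit.png').convert('L')   # load and convert to grayscale
img = img.resize((28, 28))                      # MNIST images are 28x28
pixels = 1.0 - np.asarray(img, dtype=np.float32) / 255.0  # invert: MNIST digits are bright on dark
batch = pixels.reshape(1, 784)                  # one 784-pixel row
# feed `batch` to images_placeholder and run the logits Op to get a prediction
# (rebuild the graph with batch_size=1, or tile the row to the trained batch size)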
For more detail, the book 《白话深度学习与TensorFlow》 is recommended; this post is based in part on that book.


