MNIST手写体识别--tensorflow

 
  

MNIST手写体识别--tensorflow

对于tensorflow给出的几个版本的手写体识别的代码进行分析。其中tensorflow的mnist代码在https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/tutorials/mnist

1:softmax版本

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A very simple MNIST classifier.

See extensive documentation at
https://www.tensorflow.org/get_started/mnist/beginners
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

FLAGS = None


def main(_):
 # Import data
 mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

 # Create the model
 x = tf.placeholder(tf.float32, [None, 784])
 W = tf.Variable(tf.zeros([784, 10]))
 b = tf.Variable(tf.zeros([10]))
 y = tf.matmul(x, W) + b

 # Define loss and optimizer
 y_ = tf.placeholder(tf.float32, [None, 10])

 # The raw formulation of cross-entropy,
 #
 # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
 # reduction_indices=[1]))
 #
 # can be numerically unstable.
 #
 # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
 # outputs of 'y', and then average across the batch.
 cross_entropy = tf.reduce_mean(
 tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

 sess = tf.InteractiveSession()
 tf.global_variables_initializer().run()
 # Train
 for _ in range(1000):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

 # Test trained model
 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 print(sess.run(accuracy, feed_dict={x: mnist.test.images,
 y_: mnist.test.labels}))

if __name__ == '__main__':
 parser = argparse.ArgumentParser()
 parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
 help='Directory for storing input data')
 FLAGS, unparsed = parser.parse_known_args()
 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

上面就是完整的代码,下面是分析。
import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
这里引入的是input_data然后得到数据集,但是其实这里input_data只是一个为了更加形象化代表得到数据集的模块,input_data里面从mnist.py里面导入了read_data_sets方法,而mnist.py包含了获取手写体的所有方法。首先,mnist.py会检测是否已经下载好了数据集,根据上面的代码,可以看到默认检测的地址是/tmp/tensorflow/mnist/input_data,其实这里如果不是linux类型的机器,那么就是该驱动盘,比如说代码在e盘下面,那么文件及地址就是e:/tmp/tensorflow/mnist/input_data。如果下载好了,那么就会解压数据集,转化为image-[index, y, x, depth],label-[index]的形式,并且进行one-hot操作。每个文件都存在一定的格式,应该是 2051 num_image rows cols data的顺序,num_image是图片大小,然后行列个数。最后得到的就是train(60000,784),test(10000,784)数据集,数值基于0到1之间。这里实际上定义的就是一个二维数组,有60000行,784列。每个都包含images和labels.通过mnist.train.images和mnist.train.labels等来进行索取。labels使用one-hot进行处理的,所有如果手写体分类存在10类的话,labels就是一个10列的数组。
# Create the model
 x = tf.placeholder(tf.float32, [None, 784])
 W = tf.Variable(tf.zeros([784, 10]))
 b = tf.Variable(tf.zeros([10]))
 y = tf.matmul(x, W) + b
然后定义输入,输出以及参数W,b 可以看到,W*x数组里面每一行其实就是对于每个样本的预测,x1w1的值代表结果是0的概率,当然仅仅进行到这里值不是概率,需要经过处理之后才是概率。 定义好模型之后,需要设置目标函数来进行优化。这里使用的是一种称为交叉熵的量进行优化。交叉熵是信息论里面的知识,主要用于度量两个概率分布间的差异性信息。
cross_entropy = tf.reduce_mean(
 tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
定义好目标函数,使用梯度下降方法优化。定义好优化方法之后,只需要不断地优化直到收敛就可以停止迭代。因此,现在就可以运行了。
 sess = tf.InteractiveSession()
 tf.global_variables_initializer().run()
定义的变量必须要使用显示的初始化函数。
 for _ in range(1000):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
主要的优化过程在这里,外层循环代表的是梯度下降过程,每一次迭代相当于一步,使用的100个样本点的作用是小批梯度,因为批梯度算法每次更新使用的是全部的样本点,这样导致计算复杂度很高。 所以这里使用的是小样本来进行每次的迭代。需要说明的一点是y_其实就是标签。后面的代码用于评估模型,这里不做论述。还有一点,就是在定义x的时候使用的是[None,784],原因是后面再迭代的时候需要传入x,而使用None的话,传入任何大小都行。    

查看原文: http://www.hahaszj.top/uncategorized/mnist%e6%89%8b%e5%86%99%e4%bd%93%e8%af%86%e5%88%ab-tensorflow/186

你可能感兴趣的:(MNIST手写体识别--tensorflow)