TensorFlow里面mnist导入手写数据代码分析

TensorFlow里面mnist导入手写数据代码分析

本文主要介绍了 TensorFlow(TF)手写数字识别中导入数据部分的源码分析。


  tensorflow/tensorflow/examples/tutorials/mnist   目录下,文件树如下:

[xzy@localhost mnist]$ tree
.
├── BUILD
├── fully_connected_feed.py
├── __init__.py
├── input_data.py
├── mnist_deep.py
├── mnist.py
├── mnist_softmax.py
├── mnist_softmax_xla.py
└── mnist_with_summaries.py

0 directories, 9 files

  fully_connected_feed.py  里面有一句代码如下:

from tensorflow.examples.tutorials.mnist import input_data
data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)    # <--- 调用语句,input_data 目录默认为 /tmp/tensorflow/mnist/input_data

打开  input_data.py   文件

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import tempfile
import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets  <--------------------------注意这句

进入  tensorflow/contrib/learn/python/learn/datasets  ,打开  mnist.py文件,里面有个def 定义的函数

# CVDF mirror of http://yann.lecun.com/exdb/mnist/
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
... ...

def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000,
                   seed=None):
  """Download (if needed) and load the MNIST train/validation/test datasets.

  Args:
    train_dir: directory where the gzipped MNIST files are cached.
    fake_data: if True, skip downloading and return empty fake DataSets.
    one_hot: if True, labels are returned one-hot encoded.
    dtype: dtype for the image arrays, forwarded to DataSet.
    reshape: whether to reshape/flatten images, forwarded to DataSet.
    validation_size: number of training examples split off for validation.
    seed: random seed forwarded to DataSet.

  Returns:
    A base.Datasets namedtuple with `train`, `validation` and `test` fields.

  Raises:
    ValueError: if validation_size is not in [0, number of training images].
  """
  if fake_data:

    def fake():
      return DataSet(
          [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)

    train = fake()
    validation = fake()
    test = fake()
    return base.Datasets(train=train, validation=validation, test=test)

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

  # Each file is downloaded on first use and cached in train_dir.
  local_file = base.maybe_download(TRAIN_IMAGES, train_dir, SOURCE_URL + TRAIN_IMAGES)
  with open(local_file, 'rb') as f:
    train_images = extract_images(f)

  local_file = base.maybe_download(TRAIN_LABELS, train_dir, SOURCE_URL + TRAIN_LABELS)
  with open(local_file, 'rb') as f:
    train_labels = extract_labels(f, one_hot=one_hot)

  local_file = base.maybe_download(TEST_IMAGES, train_dir, SOURCE_URL + TEST_IMAGES)
  with open(local_file, 'rb') as f:
    test_images = extract_images(f)

  local_file = base.maybe_download(TEST_LABELS, train_dir, SOURCE_URL + TEST_LABELS)
  # BUG FIX: the `with` below had an extra leading space (3 spaces instead
  # of 2), which is an IndentationError in Python.
  with open(local_file, 'rb') as f:
    test_labels = extract_labels(f, one_hot=one_hot)

  if not 0 <= validation_size <= len(train_images):
    raise ValueError(
        'Validation size should be between 0 and {}. Received: {}.'
        .format(len(train_images), validation_size))

  # The validation set is carved off the FRONT of the training data.
  validation_images = train_images[:validation_size]
  validation_labels = train_labels[:validation_size]
  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]

  options = dict(dtype=dtype, reshape=reshape, seed=seed)

  train = DataSet(train_images, train_labels, **options)
  validation = DataSet(validation_images, validation_labels, **options)
  test = DataSet(test_images, test_labels, **options)

  return base.Datasets(train=train, validation=validation, test=test)



这段代码里面调用了 maybe_download 函数下载数据,打开 tensorflow/contrib/learn/python/learn/datasets/base.py 文件

def maybe_download(filename, work_directory, source_url):
  """Download the data from source url, unless it's already here.

  Args:
      filename: string, name of the file in the directory.
      work_directory: string, path to working directory.
      source_url: url to download from if file doesn't exist.

  Returns:
      Path to resulting file.
  """
  # Make sure the cache directory exists before checking for the file.
  if not gfile.Exists(work_directory):
    gfile.MakeDirs(work_directory)

  filepath = os.path.join(work_directory, filename)
  if gfile.Exists(filepath):
    # Already downloaded on a previous run -- reuse the cached copy.
    return filepath

  # Fetch to a temporary location first, then copy into the cache.
  # NOTE: urlretrieve_with_retry is a TF helper (wrapping urllib's
  # urlretrieve with retries), not a Python built-in.
  temp_file_name, _ = urlretrieve_with_retry(source_url)
  # gfile.Copy is defined in tensorflow/python/lib/io/file_io.py.
  gfile.Copy(temp_file_name, filepath)
  with gfile.GFile(filepath) as f:
    size = f.size()
  print('Successfully downloaded', filename, size, 'bytes.')
  return filepath
 




你可能感兴趣的:(tf,mnist数据加载源码分析,mnist数据加载源码)