This article walks through the TensorFlow (TF) handwritten-digit recognition (MNIST) tutorial, analyzing the source code that imports the data.
Under the tensorflow/tensorflow/examples/tutorials/mnist directory, the file tree looks like this:
[xzy@localhost mnist]$ tree
.
├── BUILD
├── fully_connected_feed.py
├── __init__.py
├── input_data.py
├── mnist_deep.py
├── mnist.py
├── mnist_softmax.py
├── mnist_softmax_xla.py
└── mnist_with_summaries.py
0 directories, 9 files
fully_connected_feed.py contains the following code:
from tensorflow.examples.tutorials.mnist import input_data
...
data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)  # <-- the call; input_data_dir defaults to /tmp/tensorflow/mnist/input_data
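For reference, here is a minimal standalone sketch of the same call (assuming TensorFlow 1.x, where tensorflow.examples is available; the path is just the tutorial's default):

from tensorflow.examples.tutorials.mnist import input_data

data_sets = input_data.read_data_sets('/tmp/tensorflow/mnist/input_data')
print(data_sets.train.num_examples)       # 55000 (60000 minus the 5000-image validation split)
print(data_sets.validation.num_examples)  # 5000
print(data_sets.test.num_examples)        # 10000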
Open the input_data.py file:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets  # <-- note this import
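Because input_data.py only re-exports read_data_sets, both import paths should resolve to the same function object. A quick sketch to check this, assuming the TF 1.x layout shown above:

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

# input_data.py defines nothing of its own, so the two names point at one object
print(input_data.read_data_sets is read_data_sets)  # expected: True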
Go to tensorflow/contrib/learn/python/learn/datasets and open mnist.py, which defines the following function:
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
...
def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000,
                   seed=None):
  if fake_data:

    def fake():
      return DataSet(
          [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)

    train = fake()
    validation = fake()
    test = fake()
    return base.Datasets(train=train, validation=validation, test=test)

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

  local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                   SOURCE_URL + TRAIN_IMAGES)
  with open(local_file, 'rb') as f:
    train_images = extract_images(f)

  local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                   SOURCE_URL + TRAIN_LABELS)
  with open(local_file, 'rb') as f:
    train_labels = extract_labels(f, one_hot=one_hot)

  local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                   SOURCE_URL + TEST_IMAGES)
  with open(local_file, 'rb') as f:
    test_images = extract_images(f)

  local_file = base.maybe_download(TEST_LABELS, train_dir,
                                   SOURCE_URL + TEST_LABELS)
  with open(local_file, 'rb') as f:
    test_labels = extract_labels(f, one_hot=one_hot)

  if not 0 <= validation_size <= len(train_images):
    raise ValueError(
        'Validation size should be between 0 and {}. Received: {}.'
        .format(len(train_images), validation_size))

  validation_images = train_images[:validation_size]
  validation_labels = train_labels[:validation_size]
  train_images = train_images[validation_size:]
  train_labels = train_labels[validation_size:]

  options = dict(dtype=dtype, reshape=reshape, seed=seed)

  train = DataSet(train_images, train_labels, **options)
  validation = DataSet(validation_images, validation_labels, **options)
  test = DataSet(test_images, test_labels, **options)

  return base.Datasets(train=train, validation=validation, test=test)
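A hedged usage sketch of read_data_sets: with one_hot=True and a custom validation_size, the 60000 training images are split exactly as the slicing above shows. The shapes assume the default reshape=True, which flattens each 28x28 image to 784 values:

from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

mnist = read_data_sets('/tmp/tensorflow/mnist/input_data',
                       one_hot=True, validation_size=10000)
print(mnist.validation.images.shape)  # (10000, 784): taken from the front of the train files
print(mnist.train.images.shape)       # (50000, 784): the remaining 60000 - 10000 images
print(mnist.train.labels.shape)       # (50000, 10): one-hot encoded digit labels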
This code calls the maybe_download function to fetch the data. Open tensorflow/contrib/learn/python/learn/datasets/base.py:
def maybe_download(filename, work_directory, source_url):
  """Download the data from source url, unless it's already here.

  Args:
      filename: string, name of the file in the directory.
      work_directory: string, path to working directory.
      source_url: url to download from if file doesn't exist.

  Returns:
      Path to resulting file.
  """
  if not gfile.Exists(work_directory):  # create the working directory if it does not exist
    gfile.MakeDirs(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not gfile.Exists(filepath):  # only download if the target file is missing
    # urlretrieve_with_retry is a helper defined in this same base.py (not a
    # Python builtin); it wraps urllib's urlretrieve with retry logic and
    # downloads the remote data to a temporary file
    temp_file_name, _ = urlretrieve_with_retry(source_url)
    # gfile.Copy is defined in tensorflow/python/lib/io/file_io.py; it copies
    # the temporary file into the final location
    gfile.Copy(temp_file_name, filepath)
    with gfile.GFile(filepath) as f:
      size = f.size()
    print('Successfully downloaded', filename, size, 'bytes.')
  return filepath
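The same download-if-missing pattern can be sketched with only the Python standard library; this is a simplified stand-in for the gfile / urlretrieve_with_retry version above, and maybe_download_sketch is a hypothetical name:

import os
import shutil
import urllib.request

def maybe_download_sketch(filename, work_directory, source_url):
    """Download source_url into work_directory unless the file already exists."""
    if not os.path.exists(work_directory):
        os.makedirs(work_directory)
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):
        # Fetch to a temporary path first, then move into place, so an
        # interrupted download never leaves a half-written file behind.
        temp_path, _ = urllib.request.urlretrieve(source_url)
        shutil.move(temp_path, filepath)
        print('Successfully downloaded', filename,
              os.path.getsize(filepath), 'bytes.')
    return filepath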