教程地址:TensorFlow中文社区
源码: tensorflow/g3doc/tutorials/mnist/
本教程的目标是展示如何下载用于手写数字分类问题所要用到的(经典)MNIST数据集。
本教程需要使用以下文件:
文件 | 目的 |
---|---|
input_data.py |
下载用于训练和测试的MNIST数据集的源码 |
备注:
input_data.py 文件路径为:tensorflow\examples\tutorials\mnist,
内容为:
"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
import gzip
import os
import tempfile
import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# pylint: enable=unused-import
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
你会发现,该文件主要引用该目录下tensorflow\contrib\learn\python\learn\datasets\的mnist.py文件里面的read_data_sets函数
该目录结构:
MNIST是在机器学习领域中的一个经典问题。该问题解决的是把28x28像素的灰度手写数字图片识别为相应的数字,其中数字的范围从0到9.
更多详情, 请参考 Yann LeCun's MNIST page 或 Chris Olah's visualizations of MNIST.
Yann LeCun's MNIST page 也提供了训练集与测试集数据的下载。
文件 | 内容 |
---|---|
train-images-idx3-ubyte.gz |
训练集图片 - 55000 张 训练图片, 5000 张 验证图片 |
train-labels-idx1-ubyte.gz |
训练集图片对应的数字标签 |
t10k-images-idx3-ubyte.gz |
测试集图片 - 10000 张 图片 |
t10k-labels-idx1-ubyte.gz |
测试集图片对应的数字标签 |
在 input_data.py
文件中, maybe_download()
函数可以确保这些训练数据下载到本地文件夹中。
文件夹的名字在 fully_connected_feed.py
文件的顶部由一个标记变量指定,你可以根据自己的需要进行修改。
这些文件本身并没有使用标准的图片格式储存,并且需要使用input_data.py
文件中extract_images()
和extract_labels()
函数来手动解压(页面中有相关说明)。
图片数据将被解压成2维的tensor:[image index, pixel index]
其中每一项表示某一图片中特定像素的强度值, 范围从 [0, 255]
到 [-0.5, 0.5]
。 "image index"代表数据集中图片的编号, 从0到数据集的上限值。"pixel index"代表该图片中像素点得个数, 从0到图片的像素上限值。
以train-*
开头的文件中包括60000个样本,其中分割出55000个样本作为训练集,其余的5000个样本作为验证集。因为所有数据集中28x28像素的灰度图片的尺寸为784,所以训练集输出的tensor格式为[55000, 784]
。
数字标签数据被解压称1维的tensor: [image index]
,它定义了每个样本数值的类别分类。对于训练集的标签来说,这个数据规模就是:[55000]
。
底层的源码将会执行下载、解压、重构图片和标签数据来组成以下的数据集对象:
数据集 | 目的 |
---|---|
data_sets.train |
55000 组 图片和标签, 用于训练。 |
data_sets.validation |
5000 组 图片和标签, 用于迭代验证训练的准确性。 |
data_sets.test |
10000 组 图片和标签, 用于最终测试训练的准确性。 |
执行read_data_sets()
函数将会返回一个DataSet
实例,其中包含了以上三个数据集。函数DataSet.next_batch()
是用于获取以batch_size
为大小的一个元组,其中包含了一组图片和标签,该元组会被用于当前的TensorFlow运算会话中。
images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size)
在TensorFlow的源码中,MNIST数据集的读取操作在contrib\learn\python\learn\datasets\data\mnist.py中,函数是read_data_sets。
read_data_sets函数:
def read_data_sets(train_dir,
fake_data=False,
one_hot=False,
ype=dtypes.float32,
reshape=True,
validation_size=5000):
train_dir:为数据集在文件夹的位置,在这里为tensorflow\examples\tutorials\mnist\MNIST_data;
fake_data: 在官方教程中提到fake_data标记是用于单元测试的,读者可以不必理会;
one_hot:为one_hot编码,即独热码,作用是将状态值编码成状态向量,例如,数字状态共有0~9这10种,对于数字7,将它进行one_hot编码后为[0 0 0 0 0 0 0 1 0 0],这样使得状态对于计算机来说更加明确,对于矩阵操作也更加高效。
dtype:的作用是将图像像素点的灰度值从[0, 255]转变为[0.0, 1.0]。
reshape:的作用是将图像的形状从[num examples, rows, columns, depth]转变为[num examples, rows*columns] (对于二维图片,depth为1)。
validation_size:即为从训练集中抽取这么多来作为验证集。
变量定义好之后,接下来提取数据集。
with open(local_file, 'rb') as f:
train_images = extract_images(f)
看extract_images函数:
with gzip.GzipFile(fileobj=f) as bytestream:
magic = _read32(bytestream)
if magic != 2051:
raise ValueError('Invalid magic number %d in MNIST image file: %s' %
(magic, f.name))
num_images = _read32(bytestream)
rows = _read32(bytestream)
cols = _read32(bytestream)
buf = bytestream.read(rows * cols * num_images)
data = numpy.frombuffer(buf, dtype=numpy.uint8)
data = data.reshape(num_images, rows, cols, 1)
return data
如果这么看代码可能很难理解,但是如果清楚MNIST数据集文件的结构之后就好理解得多,对于MNIST的images文件:
offset | type | value | description |
0000 | 32 bit integer | 0x00000803(2051) | magic number |
0004 | 32 bit integer | 60000 | number of images |
0008 | 32 bit integer | 28 | number of rows |
0012 | 32 bit integer | 28 | number of columns |
0016 | unsigned byte | ?? | pixel |
0017 | unsigned byte | ?? | pixel |
0018 | unsigned byte | ?? | pixel |
...... | |||
xxxx | unsigned byte | ?? | pixel |
代码中_read32()的作用是从文件流中动态读取4位数据并转换为uint32的数据。
image文件的前四位为魔术码(magic number),只有检测到这4位数据的值和2051相等时,才代表这是正确的image文件,才会继续往下读取。接下来继续读取之后的4位,代表着image文件中,所包含的图片的数量(num_images)。再接着读4位,为每一幅图片的行数(rows),再后4位,为每一幅图片的列数(cols)。最后再读接下来的rows * cols * num_images位,即为所有图片的像素值。最后再将读取到的所有像素值装换为[index, rows, cols, depth]的4D矩阵。这样就将全部的image数据读取了出来。
同理,对于MNIST的labels文件:
offset | type | value | description |
0000 | 32 bit integer | 0x00000801(2049) | magic number |
0004 | 32 bit integer | 60000 | number of items |
0008 | unsigned byte | ?? | label |
0009 | unsigned byte | ?? | label |
...... | |||
xxxx | unsigned byte | ?? | label |
再看代码:
def extract_labels(f, one_hot=False, num_classes=10):
"""Extract the labels into a 1D uint8 numpy array [index].
Args:
f: A file object that can be passed into a gzip reader.
one_hot: Does one hot encoding for the result.
num_classes: Number of classes for the one hot encoding.
Returns:
labels: a 1D uint8 numpy array.
Raises:
ValueError: If the bystream doesn't start with 2049.
"""
print('Extracting', f.name)
with gzip.GzipFile(fileobj=f) as bytestream:
magic = _read32(bytestream)
if magic != 2049:
raise ValueError('Invalid magic number %d in MNIST label file: %s' %
(magic, f.name))
num_items = _read32(bytestream)
buf = bytestream.read(num_items)
labels = numpy.frombuffer(buf, dtype=numpy.uint8)
if one_hot:
return dense_to_one_hot(labels, num_classes)
return labels
同样的也是依次读取文件的魔术码以及标签总数,最后把所有图片的标签读取出来,成一个长度为num_items的1D的向量。不过代码中还有一个one_hot的部分,dense_to_one_hot的代码为:
def dense_to_one_hot(labels_dense, num_classes):
"""Convert class labels from scalars to one-hot vectors."""
num_labels = labels_dense.shape[0]
index_offset = numpy.arange(num_labels) * num_classes
labels_one_hot = numpy.zeros((num_labels, num_classes))
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
return labels_one_hot
正如文章开头提到one_hot的作用,这里将1D向量中的每一个值,编码成一个长度为num_classes的向量,向量中对应于该值的位置为1,其余为0,所以one_hot将长度为num_labels的向量编码为一个[num_labels, num_classes]的2D矩阵。
以上就是如何将MNIST数据文件中的images和labels分别提取出来的过程。
备注:
以上函数都有,“@deprecated(None, 'Please use tf.data to implement this functionality.')”。
以后的新版本估计将没有这些函数。