春节闲来无事心血来潮,动了学学 TensorFlow 的念头。
踏进TF中文门户直奔tutorial。它说TF的入门功课是手写体识别mnist,相当于一般编程的‘Hello World’。
它说要用文件input_data.py下载相关数据。试了几次连不上服务器,改用鼠标点击直接下载成功。
它还说,要用input_data.py中的函数从下载的压缩文件提取数据,并把数据分为大小不同的训练用和测试验证用三种。
不过,如何具体操作它没说清楚,只好自己动手尝试。
搜到有的博文中有它的源代码,但作者说他做了改进,让我不太放心没敢用;
在 Jupyter Notebook 中执行:
import input_data
input_data.extract_images('./MNIST_data/train-images-idx3-ubyte.gz')
报错 TypeError: only integer scalar arrays can be converted to a scalar index
把相关代码拷到Notebook中:
import gzip, numpy
def _read32(bytestream):
dt = numpy.dtype(numpy.uint32).newbyteorder('>')
return numpy.frombuffer(bytestream.read(4), dtype=dt)
def extract_images(filename):
"""Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
magic = _read32(bytestream)
if magic != 2051:
raise ValueError(
'Invalid magic number %d in MNIST image file: %s' %
(magic, filename))
num_images = _read32(bytestream)
rows = _read32(bytestream)
cols = _read32(bytestream)
buf = bytestream.read(rows * cols * num_images)
data = numpy.frombuffer(buf, dtype=numpy.uint8)
data = data.reshape(num_images, rows, cols, 1)
return data
报错依旧。修改函数 extract_images 两处代码,问题解决:
buf = bytestream.read(rows[0] * cols[0] * num_images[0])
data = data.reshape(num_images[0], rows[0], cols[0], 1)
原因很简单,rows、cols、num_images 是列表不是整数。
类似的问题,在函数 extract_labels 中也有,需将 num_items 改为 num_items[0]。
从模块 input_data.py 的代码看,函数 read_data_sets 可一步完成 数据的提取、切割等任务。
于是在 Notebook 中执行:
import input_data
input_data.read_data_sets('./MNIST_data')
'MNIST_data' 是我存放数据文件的目录。
由于模块 input_data.py 中的变量rows、cols、num_images 等列表尚未修改成整数,程序运行出错。
修改变量后,重新执行上述语句,依旧报错。
另在DOS命令窗口启动 IPython,执行上列语句,也出错。
关闭IPython,重新启动后再执行,正常了不再报错。
模块代码改错后,用 import 重新加载没有用处,只能重新热启动 IPython 或 Notebook。
估计这是 IPython 的一个 bug。
虽遇上2个坑,但也有2个收获。
一是知道了全局函数调用同模块中类的办法:
class DataSet(object):
...
def read_data_sets(train_dir, fake_data=False, one_hot=False):
class DataSets(object):
pass
data_sets = DataSets()
全局函数 read_data_sets 定义了调用类的引用 data_sets;
二是弄清了在类中以修饰方式将类属性定义成函数形式:
class DataSet(object):
...
@property
def images(self):
return self._images
@property
def labels(self):
return self._labels
...
它的用法:
import input_data
data = input_data.read_data_sets('./MNIST_data')
查看 data 的结构:
dir(data)
可见有3个属性:
'test', 'train', 'validation'
查看其类型,如:
type(data.train)
可见是 input_data.DataSet。查看其属性 images
data.train.images
可见结果:
array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], dtype=float32)