TensorFlow 起步就掉进了坑

春节闲来无事心血来潮,动了学学 TensorFlow 的念头。

踏进TF中文门户直奔tutorial。它说TF的入门功课是手写体识别mnist,相当于一般编程的‘Hello World’。

它说要用文件input_data.py下载相关数据。试了几次连不上服务器,改用鼠标点击直接下载成功。

它还说,要用input_data.py中的函数从下载的压缩文件提取数据,并把数据分为大小不同的训练用和测试验证用三种。

不过,如何具体操作它没说清楚,只好自己动手尝试。

寻找完整版的input_data.py

原产地Google的网点需要才能进去,不费那个劲了;

搜到有的博文中有它的源代码,但作者说他做了改进,让我不太放心没敢用;

找了半天,在github上找到TF中文版中有input_data.py。

遇一小坑

在 Jupyter Notebook 中执行:

import input_data
input_data.extract_images('./MNIST_data/train-images-idx3-ubyte.gz')

报错 TypeError: only integer scalar arrays can be converted to a scalar index

把相关代码拷到Notebook中:

import gzip, numpy

def _read32(bytestream):
  dt = numpy.dtype(numpy.uint32).newbyteorder('>')
  return numpy.frombuffer(bytestream.read(4), dtype=dt)

def extract_images(filename):
  """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError(
          'Invalid magic number %d in MNIST image file: %s' %
          (magic, filename))
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data

报错依旧。修改函数 extract_images 两处代码,问题解决:

    buf = bytestream.read(rows[0] * cols[0] * num_images[0])
    data = data.reshape(num_images[0], rows[0], cols[0], 1)

原因很简单,rows、cols、num_images 是列表不是整数。

类似的问题,在函数 extract_labels 中也有,需将 num_items 改为 num_items[0]。

再遇一坑

从模块 input_data.py 的代码看,函数 read_data_sets 可一步完成 数据的提取、切割等任务。

于是在 Notebook 中执行:

import input_data
input_data.read_data_sets('./MNIST_data')

'MNIST_data' 是我存放数据文件的目录。

由于模块 input_data.py 中的变量rows、cols、num_images 等列表尚未修改成整数,程序运行出错。

修改变量后,重新执行上述语句,依旧报错。

另在DOS命令窗口启动 IPython,执行上列语句,也出错。

关闭IPython,重新启动后再执行,正常了不再报错。

模块代码改错后,用 import 重新加载没有用处,只能重新热启动 IPython 或 Notebook。

估计这是 IPython 的一个 bug。

两个收获

虽遇上2个坑,但也有2个收获。

一是知道了全局函数调用同模块中类的办法:

class DataSet(object):
...
def read_data_sets(train_dir, fake_data=False, one_hot=False):
  class DataSets(object):
    pass
  data_sets = DataSets()

全局函数  read_data_sets 定义了调用类的引用 data_sets;

二是弄清了在类中以修饰方式将类属性定义成函数形式:

class DataSet(object):
  ...
  @property
  def images(self):
    return self._images

  @property
  def labels(self):
    return self._labels
  ...

它的用法:

import input_data
data = input_data.read_data_sets('./MNIST_data')

查看 data 的结构:

dir(data)
可见有3个属性:
 'test',
 'train',
 'validation'

查看其类型,如:

type(data.train)

可见是 input_data.DataSet。查看其属性 images

data.train.images

可见结果:

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

你可能感兴趣的:(TensorFlow 起步就掉进了坑)