MXNet官方文档中文版教程(5):加载数据(Iterators)

在本教程中,我们专注于如何将数据提供给训练或推断程序。 MXNet中的大多数训练和推断模块接受数据迭代器,因此简化了此过程,特别是在读取大型数据集时。 这里我们讨论 API 规则和几个提供的迭代器。

前提要求

完成本教程,我们需要:

  • MXNet
  • OpenCV Python library, Python Requests, Matplotlib & Jupyter Notebook
$ pip install opencv-python requests matplotlib jupyter
 
   
   
   
   
  • 1
  • 设置环境变量MXNET_HOME 到MXNet源码的根目录下。
$ git clone https://github.com/dmlc/mxnet ~/mxnet
$ export MXNET_HOME='~/mxnet'
 
   
   
   
   
  • 1
  • 2

MXNet数据迭代器

MXNet中的数据迭代器类似于Python迭代器对象。 在Python中,函数iter 允许通过在诸如Python列表之类的可迭代对象上调用next() 来顺序获取各项参数。 迭代器提供了一个抽象接口,用于遍历各种类型的可迭代集合,而不需要公开有关底层数据源的详细信息。

在MXNet中,数据迭代器将在每次调用next 时作为DataBatch 返回一个批次的数据。一个 DataBatch 通常包含n个训练样本及其标签。这里n是迭代器的batch_size。在数据流结束时,如果没有更多数据可读,迭代器会引发像Python iter 这样的StopIteration 异常。 DataBatch的结构定义见官网API。

可以通过DataBatch 中的provide_dataprovide_label 属性将每个训练样本及其标签上的名称,形状,类型和布局等信息作为DataDesc*数据描述符对象。

MXNet中的所有IO都通过mx.io.DataIter 及其子类来处理。 在本教程中,我们将讨论MXNet提供的一些常用的迭代器。

在深入细节之前,我们先通过导入一些所需的软件包来设置环境:

import mxnet as mx
%matplotlib inline
import os
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import tarfile

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

从内存中读取数据

当数据存储在内存中,由NDArraynumpy ndarray 支持时,我们可以使用NDArrayIter 读取数据:

import numpy as np
data = np.random.rand(100,3)
label = np.random.randint(0, 10, (100,))
data_iter = mx.io.NDArrayIter(data=data, label=label, batch_size=30)
for batch in data_iter:
    print([batch.data, batch.label, batch.pad])

 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

从CSV文件中读取数据

MXNet提供 CSVIter 从CSV文件中读取数据,用法如下:

#lets save `data` into a csv file first and try reading it back
np.savetxt('data.csv', data, delimiter=',')
data_iter = mx.io.CSVIter(data_csv='data.csv', data_shape=(3,), batch_size=30)
for batch in data_iter:
    print([batch.data, batch.pad])

 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

自定义迭代器

当内置的迭代器不符合应用需求时,可以创建自己的自定义数据迭代器。

MXNet中的迭代器应该:

  1. 实现Python2中的next() 或者Python3中的__ next() __,返回DataBatch 或者到数据流的末尾时抛出StopIteration 异常。
  2. 实现reset() 方法,从头开始读取数据
  3. 具有一个provide_data 属性,包括存储了数据的名称,形状,类型和布局信息的DataDesc 对象的列表
  4. 具有一个provide_label 属性,包括存储了标签的名称,形状,类型和布局信息的DataDesc 对象的列表

当创建一个新的迭代器时,你既可以从头开始定义一个迭代器,也可以使用一个现有的迭代器。例如,在图像字幕应用中,输入样本是图像,而标签是句子。 因此,我们可以通过以下方式创建一个新的迭代器:

  • 通过使用ImageRecordIter 创建一个image_iter,它提供多线程的预取和增强。
  • 通过使用NDArrayIter 或 rnn 包中提供的bucketing 迭代器创建caption_iter
  • next() 返回image_iter.next()caption_iter.next() 的组合结果

以下示例展示如何创建一个简单的迭代器:

class SimpleIter(mx.io.DataIter):
    def __init__(self, data_names, data_shapes, data_gen,
                 label_names, label_shapes, label_gen, num_batches=10):
        self._provide_data = zip(data_names, data_shapes)
        self._provide_label = zip(label_names, label_shapes)
        self.num_batches = num_batches
        self.data_gen = data_gen
        self.label_gen = label_gen
        self.cur_batch = 0

    def __iter__(self):
        return self

    def reset(self):
        self.cur_batch = 0

    def __next__(self):
        return self.next()

    @property
    def provide_data(self):
        return self._provide_data

    @property
    def provide_label(self):
        return self._provide_label

    def next(self):
        if self.cur_batch < self.num_batches:
            self.cur_batch += 1
            data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)]
            label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)]
            return mx.io.DataBatch(data, label)
        else:
            raise StopIteration
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35

我们可以用上面定义的SimpleIter 来训练一个简单的MLP程序:

import mxnet as mx
num_classes = 10
net = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=64)
net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=num_classes)
net = mx.sym.SoftmaxOutput(data=net, name='softmax')
print(net.list_arguments())
print(net.list_outputs())
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

这里有四个变量是可学习的参数:全连接层fc1fc2 的权重和偏置,输入数据的两个变量:用于训练样本的data 和包含相应的标签和softmax_output的softmax_label

MXNet的Symbol API中的data 变量称为自由变量。 要执行一个符号,它们需要绑定数据。

我们通过MXNet的module API,使用数据迭代器将样本提供给神经网络。

import logging
logging.basicConfig(level=logging.INFO)

n = 32
data_iter = SimpleIter(['data'], [(n, 100)],
                  [lambda s: np.random.uniform(-1, 1, s)],
                  ['softmax_label'], [(n,)],
                  [lambda s: np.random.randint(0, num_classes, s)])

mod = mx.mod.Module(symbol=net)
mod.fit(data_iter, num_epoch=5)

 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

RecordIO

RecordIO是MXNet用于数据IO的文件格式。 它紧凑地打包数据,以便从Hadoop HDFS和AWS S3等分布式文件系统进行高效的读写。 MXNet提供MXRecordIOMXIndexedRecordIO,用于数据的顺序访问和随机访问。

MXRecordIO

首先,我们来看一下如何使用MXRecordIO 顺序读写的例子。 这些文件以.rec扩展名命名。

record = mx.recordio.MXRecordIO('tmp.rec', 'w')
for i in range(5):
    record.write('record_%d'%i)
record.close()
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4

我们可以通过r 选项打开文件来读取数据,如下所示:

record = mx.recordio.MXRecordIO('tmp.rec', 'r')
while True:
    item = record.read()
    if not item:
        break
    print (item)
record.close()

 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

MXIndexedRecordIO

MXIndexedRecordIO 支持随机或索引访问数据。 我们将创建一个索引记录文件和一个相应的索引文件,如下所示:

record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
for i in range(5):
    record.write_idx(i, 'record_%d'%i)
record.close()
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4

现在,我们可以使用键值访问各个记录:

record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'r')
record.read_idx(3)
 
   
   
   
   
  • 1
  • 2

还可以列出文件中的所有键:

record.keys
 
   
   
   
   
  • 1

数据装包与拆包

.rec文件中的每个记录都可以包含任意的二进制数据。然而,大多数深度学习任务需要以标签/数据格式作为输入。mx.recordio 包为此类操作提供了一些实用功能,即:pack,unpack,pack_imgunpack_img

二进制数据的装包与拆包

packunpack 用于存储浮点数(或1维浮点数组)标签和二进制数据。 数据与头文件一起打包。

# pack
data = 'data'
label1 = 1.0
header1 = mx.recordio.IRHeader(flag=0, label=label1, id=1, id2=0)
s1 = mx.recordio.pack(header1, data)

label2 = [1.0, 2.0, 3.0]
header2 = mx.recordio.IRHeader(flag=3, label=label2, id=2, id2=0)
s2 = mx.recordio.pack(header2, data)
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
# unpack
print(mx.recordio.unpack(s1))
print(mx.recordio.unpack(s2))
 
   
   
   
   
  • 1
  • 2
  • 3
图像数据的装包与拆包

MXNet提供pack_imgunpack_img 来打包/解压图像数据。pack_img 打包的记录可以由mx.io.ImageRecordIter 加载。

data = np.ones((3,3,1), dtype=np.uint8)
label = 1.0
header = mx.recordio.IRHeader(flag=0, label=label, id=0, id2=0)
s = mx.recordio.pack_img(header, data, quality=100, img_fmt='.jpg')
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
# unpack_img
print(mx.recordio.unpack_img(s))
 
   
   
   
   
  • 1
  • 2
使用tools/im2rec.py

还可以使用MXNet src/tools 文件夹中提供的im2rec.py 实用脚本将原始图像转换为RecordIO 格式。 下面的Image IO部分将展示如何使用脚本转换为RecordIO格式。

Image IO

在本节中,我们将学习如何在MXNet中预处理和加载图像数据。

在MXNet中加载图像数据有4种方式。

  1. 使用 mx.image.imdecode 加载原始数据文件
  2. 使用在Python中实现的mx.img.ImageIter ,很方便自定义。 它可以从.rec(RecordIO)文件和原始图像文件读取。
  3. 使用C ++实现的MXNet后端的mx.io.ImageRecordIter 。 对于自定义不太灵活,但提供了多种语言绑定。
  4. 创建自定义的迭代器,继承mx.io.DataIter

预处理图像

预处理图像的方式有多种,我们列举其中的几种:

  • 使用mx.io.ImageRecordIter ,快速但不是很灵活。 对于像图像识别这样的简单任务来说,这是非常好的,但是对于更复杂的任务(如检测和分割)来说,不是很有用
  • 使用mx.recordio.unpack_img(或cv2.imreadskimage等)+ numpy。由于Python 全局解析锁(GIL),灵活但是缓慢。
  • 使用MXNet提供的mx.image 包。它以NDArray 格式存储图像,并利用MXNet的依赖引擎来自动并行化处理并规避GIL。

下面,我们演示一些由mx.image 包提供的常用的预处理示例。

下载我们可以使用的示例图像。

fname = mx.test_utils.download(url='http://data.mxnet.io/data/test_images.tar.gz', dirname='data', overwrite=False)
tar = tarfile.open(fname)
tar.extractall(path='./data')
tar.close()
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4

加载原始图像

mx.image.imdecode 可以加载图像。imdecode 提供了与OpenCV 类似的界面。

注意:你仍然需要安装OpenCV 来使用mx.image.imdecode(而不是CV2 Python库)。

img = mx.image.imdecode(open('data/test_images/ILSVRC2012_val_00000001.JPEG', 'rb').read())
plt.imshow(img.asnumpy()); plt.show()

 
   
   
   
   
  • 1
  • 2
  • 3

图像转换

# resize to w x h
tmp = mx.image.imresize(img, 100, 70)
plt.imshow(tmp.asnumpy()); plt.show()
 
   
   
   
   
  • 1
  • 2
  • 3
# crop a random w x h region from image
tmp, coord = mx.image.random_crop(img, (150, 200))
print(coord)
plt.imshow(tmp.asnumpy()); plt.show()
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4

使用图像迭代器加载数据

在了解如何使用两个内置Image迭代器读取数据之前,我们可以得到一个包含101个类的样本Caltech 101数据集,并将其转换为记录io格式。

下载并解压

fname = mx.test_utils.download(url='http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz', dirname='data', overwrite=False)
tar = tarfile.open(fname)
tar.extractall(path='./data')
tar.close()
 
   
   
   
   
  • 1
  • 2
  • 3
  • 4

我们来看看数据。可以看到,在根文件夹(./data/101_ObjectCategories)下,每个类别都有一个子文件夹(./ data / 101_ObjectCategories / yin_yang)。

现在让我们使用im2rec.py 脚本将它们转换成记录io格式。首先,我们需要列出包含所有图像文件及其类别的列表:

os.system('python %s/tools/im2rec.py --list=1 --recursive=1 --shuffle=1 --test-ratio=0.2 data/caltech data/101_ObjectCategories'%os.environ['MXNET_HOME'])
 
   
   
   
   
  • 1

生成的列表文件(./data/caltech_train.lst)格式为index\t(一个或多个label)\ tpath。在这种情况下,每个图像只有一个标签,但是可以修改列表以添加更多标签进行多标签训练。

然后我们可以使用此列表创建我们的记录io文件:

os.system("python %s/tools/im2rec.py --num-thread=4 --pass-through=1 data/caltech data/101_ObjectCategories"%os.environ['MXNET_HOME'])
 
   
   
   
   
  • 1

记录的io文件现在保存在这里(./data)

使用ImageRecordIter

ImageRecordIter 可用于加载以io格式保存的图像数据。要使用ImageRecordIter,只需通过加载记录文件就可以创建一个实例:

data_iter = mx.io.ImageRecordIter(
    path_imgrec="./data/caltech.rec", # the target record file
    data_shape=(3, 227, 227), # output data shape. An 227x227 region will be cropped from the original image.
    batch_size=4, # number of samples per batch
    resize=256 # resize the shorter edge to 256 before cropping
    # ... you can add more augumentation options as defined in ImageRecordIter.
    )
data_iter.reset()
batch = data_iter.next()
data = batch.data[0]
for i in range(4):
    plt.subplot(1,4,i+1)
    plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1,2,0)))
plt.show()

 
   
   
   
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
使用 ImageIter

ImageIter 是一个灵活的界面,支持以RecordIO和Raw格式加载图像。

data_iter = mx.image.ImageIter(batch_size=4, data_shape=(3, 227, 227),
                              path_imgrec="./data/caltech.rec",
                              path_imgidx="./data/caltech.idx" )
data_iter.reset()
batch = data_iter.next()
data = batch.data[0]
for i in range(4):
    plt.subplot(1,4,i+1)
    plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1,2,0)))
plt.show()
 
   
   
   
   

你可能感兴趣的:(MXNet)