Learning MXNet (7): Data Loading Methods

1. Reading rec files with ImageRecordIter

mxnet.io.ImageRecordIter(*args, **kwargs)

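This iterator can only read batches from rec files. Compared with custom input pipelines it is less flexible, but it is very fast. To read raw image files, use ImageIter instead.

Example: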

data_iter = mx.io.ImageRecordIter(
  path_imgrec="./sample.rec", # The target record file.
  data_shape=(3, 227, 227), # Output data shape; 227x227 region will be cropped from the original image.
  batch_size=4, # Number of items per batch.
  resize=256 # Resize the shorter edge to 256 before cropping.
  # You can specify more augmentation options. Use help(mx.io.ImageRecordIter) to see all the options.
  )
# You can now use the data_iter to access batches of images.
batch = data_iter.next() # first batch.
images = batch.data[0] # This will contain 4 (=batch_size) images each of 3x227x227.
# process the images
...
data_iter.reset() # To restart the iterator from the beginning.

The constructor also accepts a range of augmentation options. For the full parameter list, see:

http://mxnet.incubator.apache.org/versions/master/api/python/io/io.html?highlight=record

2. Reading rec files or raw images with mxnet.image.ImageIter

class mxnet.image.ImageIter(
                            batch_size,
                            data_shape, # only 3-channel RGB is supported
                            label_width=1, 
                            path_imgrec=None,
                            path_imglist=None, 
                            path_root=None, 
                            path_imgidx=None, 
                            shuffle=False, 
                            part_index=0, 
                            num_parts=1, 
                            aug_list=None, 
                            imglist=None, 
                            data_name ='data', 
                            label_name ='softmax_label', 
                            dtype='float32', 
                            last_batch_handle='pad', 
                            **kwargs
                            )

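This is a data iterator with a large number of built-in augmentation operations. It can read data from either a .rec file or raw image files.

Use the path_imgrec parameter to load a .rec file, or the path_imglist parameter to load raw images.

Specify the path_imgidx parameter to enable data partitioning (for distributed training) or shuffling.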
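To make the partitioning idea concrete, here is a minimal pure-Python sketch of how a dataset of N samples might be split among workers using part_index and num_parts. The helper name shard_indices is hypothetical, and the equal-sized contiguous slicing is an assumption about the behaviour rather than a copy of the library's code.

```python
def shard_indices(num_samples, part_index, num_parts):
    """Return the sample indices assigned to one worker (assumed contiguous slices)."""
    per_part = num_samples // num_parts      # every worker gets the same count
    start = part_index * per_part
    return list(range(start, start + per_part))

# Three hypothetical workers sharing 10 samples.
shards = [shard_indices(10, rank, 3) for rank in range(3)]
for rank, shard in enumerate(shards):
    print("worker", rank, "reads samples", shard)
```

With this slicing scheme, samples that do not divide evenly (here, sample 9) are not assigned to any worker.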

References:

http://mxnet.incubator.apache.org/versions/master/api/python/image/image.html#mxnet.image.ImageIter
https://blog.csdn.net/u014380165/article/details/74906061

A usage example:

import mxnet as mx
import numpy as np
import matplotlib.pyplot as plt

data_iter = mx.image.ImageIter(batch_size=4, data_shape=(3, 227, 227),
                              path_imgrec="./data/caltech.rec",
                              path_imgidx="./data/caltech.idx" )

# data_iter is an mxnet.image.ImageIter
# reset() resets the iterator to the beginning of the data
data_iter.reset()

# batch is an mxnet.io.DataBatch, which is what next() returns
batch = data_iter.next()

# data is an NDArray holding the images of the first batch; with batch_size=4 its shape is (4, 3, 227, 227)
data = batch.data[0]

# Display each image in the batch (transpose CHW to HWC for matplotlib)
for i in range(4):
    plt.subplot(1,4,i+1)
    plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1,2,0)))
plt.show()

Image augmentation with mx.image.CreateAugmenter()

train = mx.image.ImageIter(
        batch_size            = args.batch_size,
        data_shape          = (3,224,224),
        label_width           = 1,
        path_imglist          = args.data_train,
        path_root              = args.image_train,
        part_index            = rank,
        shuffle                  = True,
        data_name           = 'data',
        label_name           = 'softmax_label',
        aug_list                 = mx.image.CreateAugmenter((3,224,224),resize=224,rand_crop=True,rand_mirror=True,mean=True))
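The path_imglist argument expects a list file. As a hedged illustration, the sketch below writes and re-reads a file in the tab-separated "index, label(s), relative path" layout used by .lst files produced by MXNet's im2rec tool. The helper names write_lst and read_lst and the sample paths are hypothetical.

```python
import os
import tempfile

def write_lst(samples, lst_path):
    """Write (relative_path, label) pairs as 'index<TAB>label<TAB>path' lines."""
    with open(lst_path, "w") as f:
        for index, (rel_path, label) in enumerate(samples):
            f.write("{}\t{}\t{}\n".format(index, float(label), rel_path))

def read_lst(lst_path):
    """Parse the file back into {index: (labels, relative_path)}."""
    entries = {}
    with open(lst_path) as f:
        for line in f:
            parts = line.rstrip("\n").split("\t")
            # First field is the index, last is the path, everything between is a label.
            entries[int(parts[0])] = ([float(x) for x in parts[1:-1]], parts[-1])
    return entries

samples = [("cat/001.jpg", 0), ("dog/001.jpg", 1)]  # hypothetical image paths
lst_path = os.path.join(tempfile.mkdtemp(), "train.lst")
write_lst(samples, lst_path)
entries = read_lst(lst_path)
print(entries)
```

The paths in the list are relative; path_root supplies the directory they are resolved against.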

Settings and parameters of image.CreateAugmenter

image.CreateAugmenter(
                data_shape,
                resize=0,
                rand_crop=False,
                rand_resize=False,
                rand_mirror=False,
                mean=None, # if True, the default ImageNet mean is used
                std=None, # likewise, True selects the default ImageNet std
                brightness=0,
                contrast=0,
                saturation=0,
                hue=0,
                pca_noise=0,
                rand_gray=0,
                inter_method=2
                )
#Creates an augmenter list.

Parameters:

  • data_shape (tuple of int) – Shape for output data
  • resize (int) – Resize shorter edge if larger than 0 at the beginning
  • rand_crop (bool) – Whether to enable random cropping other than center crop
  • rand_resize (bool) – Whether to enable random sized cropping, require rand_crop to be enabled
  • rand_gray (float) – [0, 1], probability to convert to grayscale for all channels, the number of channels will not be reduced to 1
  • rand_mirror (bool) – Whether to apply horizontal flip to image with probability 0.5
  • mean (np.ndarray or None) – Mean pixel values for [r, g, b]
  • std (np.ndarray or None) – Standard deviations for [r, g, b]
  • brightness (float) – Brightness jittering range (percent)
  • contrast (float) – Contrast jittering range (percent)
  • saturation (float) – Saturation jittering range (percent)
  • hue (float) – Hue jittering range (percent)
  • pca_noise (float) – Pca noise level (percent)
  • inter_method (int, default=2(Area-based)) –
    Interpolation method for all resizing operations. Possible values:
      0: Nearest Neighbors Interpolation.
      1: Bilinear interpolation.
      2: Area-based (resampling using pixel area relation). It may be a preferred method for image decimation, as it gives moire-free results. But when the image is zoomed, it is similar to the Nearest Neighbors method. (used by default).
      3: Bicubic interpolation over 4x4 pixel neighborhood.
      4: Lanczos interpolation over 8x8 pixel neighborhood.
      9: Cubic for enlarge, area for shrink, bilinear for others.
      10: Random select from interpolation method mentioned above.
    Note: When shrinking an image, it will generally look best with AREA-based interpolation, whereas, when enlarging an image, it will generally look best with Bicubic (slow) or Bilinear (faster but still looks OK).
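The resize and crop parameters are easier to follow with numbers. The sketch below works through the arithmetic of "resize the shorter edge, then center crop" in plain Python. The function names are hypothetical and the integer rounding is an assumption, so exact sizes may differ by a pixel from what the library produces.

```python
def resize_short_size(width, height, size):
    """New (width, height) after scaling so that the shorter edge equals `size`."""
    if height > width:
        return size, size * height // width
    return size * width // height, size

def center_crop_box(width, height, crop_w, crop_h):
    """(x0, y0, w, h) of a crop centered in a width x height image."""
    x0 = (width - crop_w) // 2
    y0 = (height - crop_h) // 2
    return x0, y0, crop_w, crop_h

# A 640x480 image with resize=256 and data_shape=(3, 224, 224).
new_w, new_h = resize_short_size(640, 480, 256)
box = center_crop_box(new_w, new_h, 224, 224)
print((new_w, new_h), box)
```

With rand_crop=True the offsets x0 and y0 would be drawn at random instead of being centered.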

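3. Using Dataset and DataLoader

Gluon provides a way to load data with a Dataset and a DataLoader, which is very similar to the approach used in PyTorch.

Reference: https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/datasets.html

A Dataset object represents a collection of data together with the methods for loading and parsing it. Gluon offers many different Dataset classes; ArrayDataset is used below for demonstration.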

import mxnet as mx
import os
import tarfile

mx.random.seed(42) # Fix the seed for reproducibility
X = mx.random.uniform(shape=(10, 3))
y = mx.random.uniform(shape=(10, 1))
dataset = mx.gluon.data.dataset.ArrayDataset(X, y)

One important property of a Dataset is that a sample can be retrieved by its index.

sample_idx = 4
sample = dataset[sample_idx]

assert len(sample) == 2
assert sample[0].shape == (3, )#data
assert sample[1].shape == (1, )#label
print(sample)

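However, we usually do not index a Dataset directly; we use a DataLoader instead.

A DataLoader builds mini-batches from a Dataset and provides a convenient iterator interface for looping over those batches. Its most important parameter is batch_size.

Another advantage of DataLoader is that it can load data in parallel with multiple workers, set through the num_workers parameter (by default these are worker processes rather than threads).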

from multiprocessing import cpu_count
CPU_COUNT = cpu_count()

data_loader = mx.gluon.data.DataLoader(dataset, batch_size=5, num_workers=CPU_COUNT)

for X_batch, y_batch in data_loader:
    print("X_batch has shape {}, and y_batch has shape {}".format(X_batch.shape, y_batch.shape))

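The loop ends once every sample in the dataset has been returned as part of a batch. When the number of samples is not divisible by batch_size, the default behaviour is for the final iteration to return a batch smaller than batch_size. Alternatively, set the last_batch parameter to 'discard' (drop the incomplete final batch) or 'rollover' (carry the remaining samples over to the start of the next epoch).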
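The three last_batch policies can be illustrated without MXNet. The following is a minimal pure-Python sketch of the batching logic described above; the make_batches helper and its carried argument are hypothetical names for illustration, not part of the DataLoader API, and shuffling is ignored.

```python
def make_batches(indices, batch_size, last_batch="keep", carried=()):
    """Split indices into batches; return (batches, leftover for the next epoch)."""
    pool = list(carried) + list(indices)
    full = len(pool) // batch_size * batch_size   # samples that fill whole batches
    batches = [pool[i:i + batch_size] for i in range(0, full, batch_size)]
    rest = pool[full:]
    if last_batch == "keep":
        return (batches + [rest] if rest else batches), []
    if last_batch == "discard":
        return batches, []
    if last_batch == "rollover":
        return batches, rest
    raise ValueError("unknown last_batch: " + last_batch)

samples = list(range(10))
keep, _ = make_batches(samples, 4, "keep")
discard, _ = make_batches(samples, 4, "discard")
epoch1, leftover = make_batches(samples, 4, "rollover")
epoch2, leftover2 = make_batches(samples, 4, "rollover", carried=leftover)
print([len(b) for b in keep], [len(b) for b in discard], leftover, epoch2[0])
```

With 10 samples and a batch size of 4, 'keep' yields a final batch of 2, 'discard' drops it, and 'rollover' moves those 2 samples to the front of the next epoch.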

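Loading custom data with a Dataset

Gluon has many Dataset classes. Among them, mxnet.gluon.data.vision.datasets.ImageFolderDataset loads data directly from a user-defined folder and infers the label (class) of each image.

To use this class, images with different labels must be placed in different subfolders.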

train_dataset = mx.gluon.data.vision.datasets.ImageFolderDataset(training_path)
test_dataset = mx.gluon.data.vision.datasets.ImageFolderDataset(testing_path)
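To show what "inferring the label from the folder" means, here is a hedged pure-Python sketch that walks a directory tree and assigns one integer label per subfolder in sorted order. The function list_image_folder and the temporary folder layout are hypothetical, and the details (sorting, extension filter) are assumptions about the behaviour, not the library's code.

```python
import os
import tempfile

def list_image_folder(root, exts=(".jpg", ".jpeg", ".png")):
    """Return (class_names, [(file_path, label), ...]) for a root/<class>/<image> layout."""
    class_names, items = [], []
    for folder in sorted(os.listdir(root)):
        path = os.path.join(root, folder)
        if not os.path.isdir(path):
            continue                       # ignore stray files in the root
        label = len(class_names)           # label = position of the folder
        class_names.append(folder)
        for filename in sorted(os.listdir(path)):
            if os.path.splitext(filename)[1].lower() in exts:
                items.append((os.path.join(path, filename), label))
    return class_names, items

# Build a tiny hypothetical dataset: root/dog/a.jpg, root/cat/b.jpg, ...
root = tempfile.mkdtemp()
for cls, names in {"dog": ["a.jpg"], "cat": ["b.jpg", "c.png", "notes.txt"]}.items():
    os.makedirs(os.path.join(root, cls))
    for name in names:
        open(os.path.join(root, cls, name), "w").close()

class_names, items = list_image_folder(root)
print(class_names, [label for _, label in items])
```

Because labels follow the sorted folder names, the training and test directories need the same set of subfolders to get consistent label indices.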

There is also a Dataset class that reads rec files directly:

class mxnet.gluon.data.vision.datasets.ImageRecordDataset(filename, flag=1, transform=None)

A dataset wrapping over a RecordIO file containing images.
Each sample is an image and its corresponding label.

Parameters:

  • filename (str) – Path to rec file.
  • flag ({0, 1}, default 1) – If 0, always convert images to greyscale. If 1, always convert images to colored (RGB).
  • transform (function, default None) –
    A user defined callback that transforms each sample. For example:
transform=lambda data, label: (data.astype(np.float32)/255, label)

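You can also load data by defining your own Dataset class, as described in the next section.

4. Loading data with a custom Dataset

The official documentation describes a way to load data by defining a custom dataset. This approach is convenient and flexible, and can be adapted to your own needs.

Reference: https://mxnet.incubator.apache.org/versions/master/tutorials/python/data_augmentation_with_masks.html

Following the referenced document, suppose raw images need to be read according to a list file in which the first column of each line is the image path and the second column is the image label. The code below can then be used as a starting point.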

import mxnet as mx
from mxnet.gluon.data import dataset
from mxnet.gluon.data.vision import datasets, transforms
from mxnet import gluon, nd
import os
import cv2
import time
class readImageFromList(dataset.Dataset):
    def __init__(self, image_path, text_file, transform = None):
        self._transform = transform
        self._image_path = image_path
        self._text_file = text_file
        self._images = []
        self._labels = []
        # Each line of the list file is "<image path>\t<label>"; read it once and close it.
        with open(self._text_file, "r") as f:
            for line in f:
                fields = line.strip("\n").split("\t")
                self._images.append(fields[0])
                self._labels.append(fields[1])
    def __getitem__(self, idx):
        file_name = os.path.join(self._image_path, self._images[idx])
        if not os.path.isfile(file_name):
            # Fail early; otherwise `image` would be undefined below.
            raise IOError(file_name + " cannot be found.")
        image = mx.image.imread(file_name)
        #image = nd.random.uniform(shape = (3, 256, 256))
        label = int(self._labels[idx]) # is converting to an NDArray needed here?
        label = nd.array([label])
        if self._transform is not None:
            return self._transform(image), label
        else:
            return image, label
    def __len__(self):
        return len(self._images)
class imageTransform():
    def __init__(self):
        self.resize = mx.image.ResizeAug(256)
        self.crop = mx.image.RandomCropAug((224, 224))
        self.flip = mx.image.HorizontalFlipAug(p = 0.5)
        self.cast = mx.image.CastAug(typ = 'float32')
        self.bright = mx.image.BrightnessJitterAug(0.1)
        self.contrast = mx.image.ContrastJitterAug(0.1)
        self.color = mx.image.ColorJitterAug(0.1, 0.1, 0.1)
        self.rgb_mean = nd.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        self.rgb_std = nd.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    def __call__(self, image):
        img = self.resize(image)
        img = self.crop(img)
        img = self.flip(img)
        img = self.cast(img)
        img = self.color(img)
        img = img.transpose((2, 0, 1))
        img = (img.astype('float32') / 255 - self.rgb_mean) / self.rgb_std
        return img
if __name__ == "__main__":
    #transformer = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.13, 0.31)])
    transformer = imageTransform()
    image_train = readImageFromList(image_path = "dogvscat/train/", text_file = "train_list.txt", transform = transformer)
    batch_size = 128
    train_data = gluon.data.DataLoader(image_train, batch_size = batch_size, shuffle = True, num_workers = 1)
    for data, label in train_data:
        print(data.shape)
        print(label.shape)
        break

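Note that all images must have the same size after the transform so that they can be stacked into a batch.

5. Using NDArrayIter

I have not tried this yet. Note that NDArrayIter can only run single-threaded, whereas ImageIter can run multi-threaded.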

import mxnet as mx 
import numpy as np 
import random 

batch_size = 5
dataset_length = 50

# random seeds
random.seed(1)
np.random.seed(1)
mx.random.seed(1)

train_data = np.random.rand(dataset_length, 28,28).astype('float32')
train_label = np.random.randint(0, 10, (dataset_length,)).astype('float32')

data_iter = mx.io.NDArrayIter(data=train_data, label=train_label, batch_size=batch_size, shuffle=False, data_name='data', label_name='softmax_label')
for batch in data_iter:
    print(batch.data[0].shape, batch.label[0])
    break

Appendix

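As the loading methods above show, there are two main approaches:

  • The traditional DataIter approach, which returns a DataBatch whose data and label attributes hold lists of arrays.
  • The gluon approach of Dataset + DataLoader, which returns a (data, label) tuple.

However, a DataIter cannot be passed to a DataLoader directly. When using gluon, it is recommended to wrap the DataIter so that it can be consumed like a DataLoader. Augmentation does not need special handling, since it can be done inside the DataIter.

A simple class can wrap a DataIter object so that it can be used in a typical gluon training loop. The class can be applied to objects such as mxnet.image.ImageIter and mxnet.io.ImageRecordIter.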

class DataIterLoader():
    def __init__(self, data_iter):
        self.data_iter = data_iter

    def __iter__(self):
        self.data_iter.reset()
        return self

    def __next__(self):
        batch = self.data_iter.__next__()
        assert len(batch.data) == len(batch.label) == 1
        data = batch.data[0]
        label = batch.label[0]
        return data, label

    def next(self):
        return self.__next__() # for Python 2

# X and y are the arrays created in the ArrayDataset example in section 3.
data_iter = mx.io.NDArrayIter(data=X, label=y, batch_size=5)
data_iter_loader = DataIterLoader(data_iter)
for X_batch, y_batch in data_iter_loader:
    assert X_batch.shape == (5, 3)
    assert y_batch.shape == (5, 1)
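One detail of the wrapper worth highlighting is that __iter__ calls reset(), which is what allows the same object to be looped over for several epochs. The sketch below demonstrates this without MXNet, using FakeBatch and FakeDataIter as hypothetical stand-ins that only mimic the reset()/__next__() interface and the batch.data / batch.label lists of a real DataIter.

```python
class FakeBatch:
    """Stand-in for a DataBatch: data and label are one-element lists."""
    def __init__(self, data, label):
        self.data = [data]
        self.label = [label]

class FakeDataIter:
    """Stand-in for a DataIter with reset() and __next__()."""
    def __init__(self, samples, batch_size):
        self.samples, self.batch_size, self.cursor = samples, batch_size, 0
    def reset(self):
        self.cursor = 0
    def __next__(self):
        if self.cursor >= len(self.samples):
            raise StopIteration
        chunk = self.samples[self.cursor:self.cursor + self.batch_size]
        self.cursor += self.batch_size
        return FakeBatch([x for x, _ in chunk], [y for _, y in chunk])

class DataIterLoader:
    """Same wrapper as above, repeated so this sketch is self-contained."""
    def __init__(self, data_iter):
        self.data_iter = data_iter
    def __iter__(self):
        self.data_iter.reset()          # rewind at the start of every epoch
        return self
    def __next__(self):
        batch = self.data_iter.__next__()
        assert len(batch.data) == len(batch.label) == 1
        return batch.data[0], batch.label[0]

loader = DataIterLoader(FakeDataIter([(i, i % 2) for i in range(6)], batch_size=3))
epochs = [[(data, label) for data, label in loader] for _ in range(2)]
print(epochs[0])
```

Without the reset() call in __iter__, the second epoch would be empty because the underlying iterator would already be exhausted.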
