mxnet.io.ImageRecordIter(*args, **kwargs)
This iterator can only read batches from a .rec file. Compared with a customized input pipeline it is less flexible, but it is very fast. To read raw images directly, use ImageIter instead.
Example:
data_iter = mx.io.ImageRecordIter(
    path_imgrec="./sample.rec",  # The target record file.
    data_shape=(3, 227, 227),    # Output data shape; 227x227 region will be cropped from the original image.
    batch_size=4,                # Number of items per batch.
    resize=256                   # Resize the shorter edge to 256 before cropping.
    # You can specify more augmentation options. Use help(mx.io.ImageRecordIter) to see all the options.
)
# You can now use the data_iter to access batches of images.
batch = data_iter.next() # first batch.
images = batch.data[0] # This will contain 4 (=batch_size) images each of 3x227x227.
# process the images
...
data_iter.reset() # To restart the iterator from the beginning.
Various augmentation operations can be specified through the keyword arguments; the full list of options is documented at
http://mxnet.incubator.apache.org/versions/master/api/python/io/io.html?highlight=record
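For example, several of these augmentation options can be passed directly as keyword arguments. Below is a minimal sketch reusing the sample.rec file from above; the concrete option values are only illustrative, see the page above for the full list and defaults.
aug_iter = mx.io.ImageRecordIter(
    path_imgrec="./sample.rec",
    data_shape=(3, 227, 227),
    batch_size=4,
    resize=256,
    shuffle=True,      # shuffle the order of records every epoch
    rand_crop=True,    # take a random 227x227 crop instead of a center crop
    rand_mirror=True,  # randomly mirror the image horizontally
    mean_r=123.68,     # per-channel mean subtraction (R, G, B)
    mean_g=116.78,
    mean_b=103.94
)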
class mxnet.image.ImageIter(
    batch_size,
    data_shape,        # only 3-channel RGB images are supported
    label_width=1,
    path_imgrec=None,
    path_imglist=None,
    path_root=None,
    path_imgidx=None,
    shuffle=False,
    part_index=0,
    num_parts=1,
    aug_list=None,
    imglist=None,
    data_name='data',
    label_name='softmax_label',
    dtype='float32',
    last_batch_handle='pad',
    **kwargs
)
This is a data iterator with a rich set of augmentation operations. It supports reading data either from a .rec file or from raw image files: use the path_imgrec parameter to load a .rec file, and the path_imglist parameter to load raw image data. To use shuffling or data partitioning for distributed training, also specify the path_imgidx parameter.
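For instance, a minimal sketch of partitioned reading for distributed training, using the caltech.rec/.idx files from the example below; num_parts and part_index split the records across workers, and both shuffling and partitioning of a .rec file require the .idx file:
part_iter = mx.image.ImageIter(
    batch_size=4,
    data_shape=(3, 227, 227),
    path_imgrec="./data/caltech.rec",
    path_imgidx="./data/caltech.idx",
    shuffle=True,
    num_parts=2,    # total number of partitions
    part_index=0    # this worker reads partition 0
)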
References:
http://mxnet.incubator.apache.org/versions/master/api/python/image/image.html#mxnet.image.ImageIter
https://blog.csdn.net/u014380165/article/details/74906061
A usage example:
import mxnet as mx
import numpy as np
import matplotlib.pyplot as plt

data_iter = mx.image.ImageIter(batch_size=4, data_shape=(3, 227, 227),
                               path_imgrec="./data/caltech.rec",
                               path_imgidx="./data/caltech.idx")
# data_iter is of type mxnet.image.ImageIter
# reset() resets the iterator to the beginning of the data
data_iter.reset()
# batch is of type mxnet.io.DataBatch, since next() returns a DataBatch
batch = data_iter.next()
# data is an NDArray holding the first batch; with batch_size=4 its shape is 4x3x227x227
data = batch.data[0]
# this loop reads each image in the batch and displays it
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1, 2, 0)))
plt.show()
Use mx.image.CreateAugmenter() for image augmentation:
train = mx.image.ImageIter(
    batch_size=args.batch_size,
    data_shape=(3, 224, 224),
    label_width=1,
    path_imglist=args.data_train,
    path_root=args.image_train,
    part_index=rank,
    shuffle=True,
    data_name='data',
    label_name='softmax_label',
    aug_list=mx.image.CreateAugmenter((3, 224, 224), resize=224, rand_crop=True, rand_mirror=True, mean=True))
Settings and parameters of image.CreateAugmenter:
image.CreateAugmenter(
    data_shape,
    resize=0,
    rand_crop=False,
    rand_resize=False,
    rand_mirror=False,
    mean=None,       # if True, the default ImageNet mean is used
    std=None,        # if True, the default ImageNet std is used
    brightness=0,
    contrast=0,
    saturation=0,
    hue=0,
    pca_noise=0,
    rand_gray=0,
    inter_method=2
)
#Creates an augmenter list.
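The returned value is a list of augmenter objects; each augmenter can be called on a single image (an NDArray in HWC layout), which is how ImageIter applies them internally. A minimal sketch, assuming "test.jpg" is a placeholder image that is at least 224x224 pixels:
augs = mx.image.CreateAugmenter(data_shape=(3, 224, 224), rand_crop=True, rand_mirror=True)
img = mx.image.imread("test.jpg")   # HWC, uint8
for aug in augs:
    img = aug(img)                  # apply each augmenter in turn
print(img.shape)                    # (224, 224, 3)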
Gluon provides a way to load data with Dataset and DataLoader, which is very similar to PyTorch.
Reference: https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/datasets.html
A Dataset object represents a collection of data together with the logic for loading and parsing it. Gluon provides many different Dataset classes; ArrayDataset is used below as a demonstration.
import mxnet as mx
import os
import tarfile
mx.random.seed(42) # Fix the seed for reproducibility
X = mx.random.uniform(shape=(10, 3))
y = mx.random.uniform(shape=(10, 1))
dataset = mx.gluon.data.dataset.ArrayDataset(X, y)
An important feature of a Dataset is that a sample can be retrieved by index.
sample_idx = 4
sample = dataset[sample_idx]
assert len(sample) == 2
assert sample[0].shape == (3, )#data
assert sample[1].shape == (1, )#label
print(sample)
However, we usually do not index a Dataset directly; instead we use a DataLoader. A DataLoader builds mini-batches from a Dataset and provides a convenient iterator interface for looping over them; its most important parameter is batch_size. Another advantage of DataLoader is that it can load data in parallel with multiple worker processes, controlled by the num_workers parameter.
from multiprocessing import cpu_count
CPU_COUNT = cpu_count()
data_loader = mx.gluon.data.DataLoader(dataset, batch_size=5, num_workers=CPU_COUNT)
for X_batch, y_batch in data_loader:
    print("X_batch has shape {}, and y_batch has shape {}".format(X_batch.shape, y_batch.shape))
The loader stops iterating once every sample in the dataset has been returned as part of a batch. If the number of samples is not divisible by batch_size, the default behavior is that the last iteration returns a batch smaller than batch_size; alternatively, the last_batch parameter can be set to 'discard' (drop the incomplete last batch) or 'rollover' (carry the remaining samples over to the next epoch).
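A minimal sketch of the last_batch options, using the 10-sample dataset defined above with a batch size of 3 (10 is not divisible by 3):
# 'keep' (the default) returns a final batch of size 1; 'discard' drops it;
# 'rollover' carries the leftover sample over to the next epoch.
loader_discard = mx.gluon.data.DataLoader(dataset, batch_size=3, last_batch='discard')
loader_rollover = mx.gluon.data.DataLoader(dataset, batch_size=3, last_batch='rollover')
print(sum(1 for _ in loader_discard))   # 3 batches; the incomplete batch is dropped
print(sum(1 for _ in loader_rollover))  # 3 batches in the first epoch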
Loading custom data with a Dataset
Gluon ships with many Dataset classes. Among them, mxnet.gluon.data.vision.datasets.ImageFolderDataset loads data directly from a user-defined folder and infers the labels (classes) from the folder structure. To use this class, images with different labels must be placed in different subfolders.
train_dataset = mx.gluon.data.vision.datasets.ImageFolderDataset(training_path)
test_dataset = mx.gluon.data.vision.datasets.ImageFolderDataset(testing_path)
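For example, with a layout like training_path/cat/*.jpg and training_path/dog/*.jpg (the class names here are only illustrative), the inferred classes can be inspected through the synsets attribute, and the label of each sample is the index of its subfolder:
print(train_dataset.synsets)     # e.g. ['cat', 'dog']
image, label = train_dataset[0]  # image is an HWC uint8 NDArray
print(image.shape, label)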
There is also a Dataset class that reads .rec files directly:
class mxnet.gluon.data.vision.datasets.ImageRecordDataset(filename, flag=1, transform=None)
A dataset wrapping over a RecordIO file containing images.
Each sample is an image and its corresponding label.
For example, the transform parameter can be a function applied to every (data, label) pair:
transform=lambda data, label: (data.astype(np.float32)/255, label)
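A minimal sketch of loading the caltech.rec file from the earlier example with this transform (np refers to numpy; accessing a single sample avoids any batching issues when the stored images have different sizes):
import numpy as np
rec_dataset = mx.gluon.data.vision.datasets.ImageRecordDataset(
    "./data/caltech.rec",
    transform=lambda data, label: (data.astype(np.float32)/255, label))
image, label = rec_dataset[0]    # image is an HWC float32 NDArray after the transform
print(image.shape, image.dtype, label)
# If all images in the .rec file share the same size, the dataset can also be
# wrapped in mx.gluon.data.DataLoader like any other Dataset.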
In addition, data can also be loaded by defining a custom Dataset class, as described in the next section.
The official documentation shows how to load data by defining a custom dataset. This approach is convenient and flexible and can be adapted to your own needs.
Reference: https://mxnet.incubator.apache.org/versions/master/tutorials/python/data_augmentation_with_masks.html
Following that document, suppose you need to read raw images from a list file in which the first column of each line is the image path and the second column is the label. You can refer to the code below.
import mxnet as mx
from mxnet.gluon.data import dataset
from mxnet.gluon.data.vision import datasets, transforms
from mxnet import gluon, nd
import os
import cv2
import time

class readImageFromList(dataset.Dataset):
    def __init__(self, image_path, text_file, transform=None):
        self._transform = transform
        self._image_path = image_path
        self._text_file = text_file
        self._images = [line.strip("\n").split("\t")[0] for line in open(self._text_file, "r")]
        self._labels = [line.strip("\n").split("\t")[1] for line in open(self._text_file, "r")]

    def __getitem__(self, idx):
        file_name = os.path.join(self._image_path, self._images[idx])
        if os.path.isfile(file_name):
            image = mx.image.imread(file_name)
            #image = nd.random.uniform(shape = (3, 256, 256))
        else:
            raise IOError(file_name + " cannot be found.")
        label = int(self._labels[idx])  # does the label need to be converted to an NDArray here?
        label = nd.array([label])
        if self._transform is not None:
            return self._transform(image), label
        else:
            return image, label

    def __len__(self):
        return len(self._images)

class imageTransform():
    def __init__(self):
        self.resize = mx.image.ResizeAug(256)
        self.crop = mx.image.RandomCropAug((224, 224))
        self.flip = mx.image.HorizontalFlipAug(p=0.5)
        self.cast = mx.image.CastAug(typ='float32')
        self.bright = mx.image.BrightnessJitterAug(0.1)
        self.contrast = mx.image.ContrastJitterAug(0.1)
        self.color = mx.image.ColorJitterAug(0.1, 0.1, 0.1)
        self.rgb_mean = nd.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        self.rgb_std = nd.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

    def __call__(self, image):
        img = self.resize(image)
        img = self.crop(img)
        img = self.flip(img)
        img = self.cast(img)
        img = self.color(img)
        img = img.transpose((2, 0, 1))
        img = (img.astype('float32') / 255 - self.rgb_mean) / self.rgb_std
        return img

if __name__ == "__main__":
    #transformer = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.13, 0.31)])
    transformer = imageTransform()
    image_train = readImageFromList(image_path="dogvscat/train/", text_file="train_list.txt", transform=transformer)
    batch_size = 128
    train_data = gluon.data.DataLoader(image_train, batch_size=batch_size, shuffle=True, num_workers=1)
    for data, label in train_data:
        print(data.shape)
        print(label.shape)
        break
Note that the images (after any transform) must all have the same size, otherwise they cannot be stacked into a batch.
I have not tried the following yet, but note that NDArrayIter only supports single-threaded loading, while ImageIter can load data with multiple threads.
import mxnet as mx
import numpy as np
import random
batch_size = 5
dataset_length = 50
# random seeds
random.seed(1)
np.random.seed(1)
mx.random.seed(1)
train_data = np.random.rand(dataset_length, 28,28).astype('float32')
train_label = np.random.randint(0, 10, (dataset_length,)).astype('float32')
data_iter = mx.io.NDArrayIter(data=train_data, label=train_label, batch_size=batch_size, shuffle=False, data_name='data', label_name='softmax_label')
for batch in data_iter:
    print(batch.data[0].shape, batch.label[0])
    break
From the several loading approaches above, we can see that they fall into two main categories: the DataIter-based iterators (mx.io and mx.image) and the Gluon Dataset plus DataLoader approach.
However, the data produced by a DataIter cannot be used directly the way a DataLoader is used. When working with Gluon, it is recommended to convert the DataIter into something a Gluon training loop can consume like a DataLoader; augmentation does not need special treatment, since those operations can still be done inside the DataIter.
A simple class can wrap a DataIter object into a type that a typical Gluon training loop can consume. It can be used with objects such as mxnet.image.ImageIter and mxnet.io.ImageRecordIter.
class DataIterLoader():
    def __init__(self, data_iter):
        self.data_iter = data_iter

    def __iter__(self):
        self.data_iter.reset()
        return self

    def __next__(self):
        batch = self.data_iter.__next__()
        assert len(batch.data) == len(batch.label) == 1
        data = batch.data[0]
        label = batch.label[0]
        return data, label

    def next(self):
        return self.__next__()  # for Python 2

data_iter = mx.io.NDArrayIter(data=X, label=y, batch_size=5)
data_iter_loader = DataIterLoader(data_iter)
for X_batch, y_batch in data_iter_loader:
    assert X_batch.shape == (5, 3)
    assert y_batch.shape == (5, 1)
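The same wrapper can be applied to an ImageRecordIter, as sketched below with the sample.rec file from the first example; the resulting shapes depend on batch_size and data_shape:
rec_iter = mx.io.ImageRecordIter(path_imgrec="./sample.rec",
                                 data_shape=(3, 227, 227),
                                 batch_size=4, resize=256)
rec_loader = DataIterLoader(rec_iter)
for data, label in rec_loader:
    print(data.shape, label.shape)   # (4, 3, 227, 227) and (4,)
    break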