mxnet支持两种数据读取方式:一种是从原图读取,另一种是从mxnet特有的数据格式.rec数据读取(类似caffe的lmdb和TensorFlow的 tfrecords)。
mxnet提供工具来制作数据(类似caffe),这里只介绍如何使用这个工具
以kaggle猫狗大战的数据集为例
将cat和dog各自移动到一个文件夹
mkdir dog
mkdir cat
mv dog.* dog/
mv cat.* cat/
是用如下shell脚本,调用tools/im2rec.py制作 lst文件(看来caffe制作lmdb需要的txt路径文件)
注意路径
python3 ~/mxnet/tools/im2rec.py /home/~/data/ /home/~/data/train/ --list --recursive --train-ratio 0.9
这时候在 /home/~/data下就有data_train.lst 和data_val.lst文件
用如下shell脚本,制作rec文件(类似caffe的lmdb)
python3 ~/mxnet/tools/im2rec.py --num-thread 32 /home/~/data_train.lst /home/~/data/train/
意思就是根据lst文件生成rec文件
简单看一下lst文件的内容
3197 0.000000 cat/cat.1625.jpg
15084 1.000000 dog/dog.12322.jpg
1479 0.000000 cat/cat.11328.jpg
5262 0.000000 cat/cat.3484.jpg
20714 1.000000 dog/dog.6140.jpg
后面的是相对路径,所以在生成rec文件时,还需要补全前面的路径,上面脚本第四个参数
1、mx.image.ImageIter()
看参数解释,这里直接谷歌翻译一下
具有大量扩充选择的“”“图像数据迭代器。
该迭代器支持读取.rec文件和原始图像文件。
要从.rec文件加载输入图像,请使用path_imgrec参数,并从原始图像加载
文件,使用`path_imglist`和`path_root`参数。
要使用数据分区(用于分布式训练)或改组,请指定`path_imgidx`参数。
参量
----------
batch_size:整数
每批次的示例数。
data_shape:元组
数据形状(通道,高度,宽度)格式。
目前,仅支持具有3个通道的RGB图像。
label_width:int,可选
每个示例的标签数。默认标签宽度为1。
path_imgrec:str
图像记录文件(.rec)的路径。
使用tools / im2rec.py或bin / im2rec创建。
path_imglist:str
图像列表(.lst)的路径。
使用tools / im2rec.py或自定义脚本创建。
格式:索引的制表符分隔记录,一个或多个标签和relative_path_from_root。
imglist:列表
带有标签的图像列表。
每个项目都是一个列表[imagelabel:float或float列表,imgpath]。
path_root:str
图像文件的根文件夹。
path_imgidx:str
图像索引文件的路径。使用.rec源时需要进行分区和改组。
随机播放:布尔
是否在每次迭代开始时随机播放所有图像。
对于HDD可能很慢。
part_index:整数
分区索引。
num_parts:int
分区总数。
data_name:str
提供的符号的数据名称。
label_name:str
提供的符号的标签名称。
dtype:str
标签数据类型。默认值:float32。其他选项:int32,int64,float64
last_batch_handle:str,可选
如何处理最后一批。
此参数可以是“ pad”(默认),“ discard”或“ roll_over”。
如果是'pad',则将从最后开始填充最后一批数据
如果“放弃”,最后一批将被丢弃
如果为'roll_over',其余元素将被移至下一个迭代
kwargs:...
有关创建增强器的更多参数。请参阅mx.image.CreateAugmenter。
“”
可以根据lst读取,也可以根据rec文件读取
train_data = mx.image.ImageIter(batch_size=32,
data_shape=(3, 224, 224),
path_imglist='../data_train.lst',
path_root='../data/train', # 图像在的目录
shuffle=True)
train_data.reset()
# print(train_data)
data_batch = train_data.next()
data = data_batch.data[0]
print(data_batch)
print(data.shape)
print(data[0].shape)
plt.figure()
for i in range(5):
save_image = data[i].astype('uint8').asnumpy().transpose((1, 2, 0))
plt.subplot(1, 5, i+1)
plt.imshow(save_image)
plt.show()
根据rec文件读取
train_data = mx.image.ImageIter(batch_size=32,
data_shape=(3, 224, 224),
path_imgrec="../data_train.rec",
path_imgidx="../data_train.idx",
shuffle=True)
train_data.reset()
data_batch = train_data.next()
data = data_batch.data[0]
print(data_batch)
print(data.shape)
print(data[0].shape)
plt.figure()
for i in range(5):
save_image = data[i].astype('uint8').asnumpy().transpose((1, 2, 0))
plt.subplot(1, 5, i + 1)
plt.imshow(save_image)
plt.show()
2、mx.io.ImageRecordIter()
读取rec文件,用法几乎一样
train_data = mx.io.ImageRecordIter(path_imgrec="../data_train.rec", # the target record file
path_imgidx="../data_train.idx",
data_shape=(3, 224, 224), # output data shape. An 227x227 region will be cropped from the original image.
batch_size=32,
shuffle=True)
train_data.reset()
data_batch = train_data.next()
data = data_batch.data[0]
print(data_batch)
print(data.shape)
print(data[0].shape)
plt.figure()
for i in range(5):
save_image = data[i].astype('uint8').asnumpy().transpose((1, 2, 0))
plt.subplot(1, 5, i + 1)
plt.imshow(save_image)
plt.show()
值得注意的是,
mx.image.ImageDetIter()是用来读取检测相关数据的,label的形状是标号的形状是batch_size x num_object_per_image x 5。