mxnet学习一、数据制作以及读取

文章目录

    • 一、制作数据
    • 二、mxnet的几种读取数据方式

一、制作数据

mxnet支持两种数据读取方式:一种是从原图读取,另一种是从mxnet特有的数据格式.rec数据读取(类似caffe的lmdb和TensorFlow的 tfrecords)。
mxnet提供工具来制作数据(类似caffe),这里只介绍如何使用这个工具
以kaggle猫狗大战的数据集为例
将cat和dog各自移动到一个文件夹

mkdir dog
mkdir cat
mv dog.* dog/
mv cat.* cat/

是用如下shell脚本,调用tools/im2rec.py制作 lst文件(看来caffe制作lmdb需要的txt路径文件)
注意路径

python3 ~/mxnet/tools/im2rec.py /home/~/data/ /home/~/data/train/ --list --recursive --train-ratio 0.9

这时候在 /home/~/data下就有data_train.lst 和data_val.lst文件

用如下shell脚本,制作rec文件(类似caffe的lmdb)

python3 ~/mxnet/tools/im2rec.py --num-thread 32 /home//data_train.lst  /home//data/train/

意思就是根据lst文件生成rec文件
简单看一下lst文件的内容
3197 0.000000 cat/cat.1625.jpg
15084 1.000000 dog/dog.12322.jpg
1479 0.000000 cat/cat.11328.jpg
5262 0.000000 cat/cat.3484.jpg
20714 1.000000 dog/dog.6140.jpg
后面的是相对路径,所以在生成rec文件时,还需要补全前面的路径,上面脚本第四个参数

二、mxnet的几种读取数据方式

1、mx.image.ImageIter()
看参数解释,这里直接谷歌翻译一下

具有大量扩充选择的“”“图像数据迭代器。
该迭代器支持读取.rec文件和原始图像文件。

要从.rec文件加载输入图像,请使用path_imgrec参数,并从原始图像加载
文件,使用`path_imglist`和`path_root`参数。

要使用数据分区(用于分布式训练)或改组,请指定`path_imgidx`参数。

参量
----------
batch_size:整数
    每批次的示例数。
data_shape:元组
    数据形状(通道,高度,宽度)格式。
    目前,仅支持具有3个通道的RGB图像。
label_width:int,可选
    每个示例的标签数。默认标签宽度为1。
path_imgrec:str
    图像记录文件(.rec)的路径。
    使用tools / im2rec.py或bin / im2rec创建。
path_imglist:str
    图像列表(.lst)的路径。
    使用tools / im2rec.py或自定义脚本创建。
    格式:索引的制表符分隔记录,一个或多个标签和relative_path_from_root。
imglist:列表
    带有标签的图像列表。
    每个项目都是一个列表[imagelabel:float或float列表,imgpath]。
path_root:str
    图像文件的根文件夹。
path_imgidx:str
    图像索引文件的路径。使用.rec源时需要进行分区和改组。
随机播放:布尔
    是否在每次迭代开始时随机播放所有图像。
    对于HDD可能很慢。
part_index:整数
    分区索引。
num_parts:int
    分区总数。
data_name:str
    提供的符号的数据名称。
label_name:str
    提供的符号的标签名称。
dtype:str
    标签数据类型。默认值:float32。其他选项:int32,int64,float64
last_batch_handle:str,可选
    如何处理最后一批。
    此参数可以是“ pad”(默认),“ discard”或“ roll_over”。
    如果是'pad',则将从最后开始填充最后一批数据
    如果“放弃”,最后一批将被丢弃
    如果为'roll_over',其余元素将被移至下一个迭代
kwargs:...
    有关创建增强器的更多参数。请参阅mx.image.CreateAugmenter。
“”

可以根据lst读取,也可以根据rec文件读取

 train_data = mx.image.ImageIter(batch_size=32,
                                 data_shape=(3, 224, 224),
                                 path_imglist='../data_train.lst',
                                 path_root='../data/train',  # 图像在的目录
                                 shuffle=True)

 train_data.reset()
 # print(train_data)

 data_batch = train_data.next()
 data = data_batch.data[0]
 print(data_batch)
 print(data.shape)
 print(data[0].shape)
 plt.figure()
 for i in range(5):
     save_image = data[i].astype('uint8').asnumpy().transpose((1, 2, 0))
     plt.subplot(1, 5, i+1)
     plt.imshow(save_image)
 plt.show()

根据rec文件读取

 train_data = mx.image.ImageIter(batch_size=32,
                                       data_shape=(3, 224, 224),
                                       path_imgrec="../data_train.rec",
                                       path_imgidx="../data_train.idx",
                                       shuffle=True)

 train_data.reset()
 data_batch = train_data.next()
 data = data_batch.data[0]
 print(data_batch)
 print(data.shape)
 print(data[0].shape)
 plt.figure()
 for i in range(5):
     save_image = data[i].astype('uint8').asnumpy().transpose((1, 2, 0))
     plt.subplot(1, 5, i + 1)
     plt.imshow(save_image)
 plt.show()

2、mx.io.ImageRecordIter()
读取rec文件,用法几乎一样

    train_data = mx.io.ImageRecordIter(path_imgrec="../data_train.rec",  # the target record file
                                       path_imgidx="../data_train.idx",
                                       data_shape=(3, 224, 224),          # output data shape. An 227x227 region will be cropped from the original image.
                                       batch_size=32,
                                       shuffle=True)

    train_data.reset()
    data_batch = train_data.next()
    data = data_batch.data[0]
    print(data_batch)
    print(data.shape)
    print(data[0].shape)
    plt.figure()
    for i in range(5):
        save_image = data[i].astype('uint8').asnumpy().transpose((1, 2, 0))
        plt.subplot(1, 5, i + 1)
        plt.imshow(save_image)
    plt.show()

值得注意的是,
mx.image.ImageDetIter()是用来读取检测相关数据的,label的形状是标号的形状是batch_size x num_object_per_image x 5。

你可能感兴趣的:(Mxnet)