读取图片并进行数据增强的代码（转载，可直接使用）：https://blog.csdn.net/qq_35606924/article/details/80033850
注意：path_root = "文件夹"，即图片的根目录；lst 文件中的图片相对路径以它为起点。
lst 文件由 im2rec.py 生成，示例见下方，每行依次为：索引、标签（label）、图片相对路径。
生成 lst 之前，图片必须按类别放在子文件夹中，目录结构为：文件夹/类别1/*.jpg、文件夹/类别2/*.jpg、文件夹/类别3/*.jpg 等。
python im2rec.py --list --recursive train imgpath
python im2rec.py --list --recursive train 文件夹      # 在当前目录下，为“文件夹”中所有按类别存放的图片生成 train.lst 文件
python im2rec.py --list --recursive test 文件夹test   # 在当前目录下，为“文件夹test”中所有按类别存放的图片生成 test.lst 文件（下文代码读取的是 val.lst，请保持文件名一致）
177 19.000000 17271/IMG20191018175330.jpg
47 5.000000 17161/IMG_20191018_175135.jpg
40 4.000000 17157/IMG_20191018_175540.jpg
230 25.000000 690/IMG_20191018_175252.jpg
162 18.000000 17253/IMG_20191018_175111.jpg
157 17.000000 17238/IMG_20191018_175507.jpg
##########################
###############
import mxnet as mx
class custom_iter(mx.io.DataIter):
    """Wrap an MXNet data iterator and expose its label as ``softmax_label``.

    Data, reset and iteration are delegated to the wrapped iterator (here an
    ``mx.image.ImageIter``). Each batch carries exactly one copy of the
    wrapped iterator's first label array per entry declared in
    ``provide_label``. This keeps the batch consistent with the label
    descriptors that ``Module.bind`` sees.

    Args:
        data_iter: the underlying ``mx.io.DataIter``. It must expose
            ``batch_size``, ``provide_data``, ``provide_label``, ``reset``,
            ``hard_reset`` and ``next``.
    """

    def __init__(self, data_iter):
        super(custom_iter, self).__init__()
        self.data_iter = data_iter
        self.batch_size = self.data_iter.batch_size

    @property
    def provide_data(self):
        """Data descriptors, passed through unchanged from the wrapped iterator."""
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        """Label descriptors: one ``softmax_label`` with the wrapped label's shape.

        To feed a second loss with the same label, append an entry such as
        ``('other_loss_label', label_shape)`` to the returned list. ``next``
        replicates the label once per entry, so no other change is needed.
        """
        # provide_label[0] is a (name, shape) descriptor; keep only the shape.
        label_shape = self.data_iter.provide_label[0][1]
        return [('softmax_label', label_shape)]

    def hard_reset(self):
        """Delegate to the wrapped iterator's ``hard_reset``."""
        self.data_iter.hard_reset()

    def reset(self):
        """Delegate to the wrapped iterator's ``reset``."""
        self.data_iter.reset()

    def next(self):
        """Return the next batch, replicating the label per declared label entry.

        Whatever the wrapped iterator's ``next`` raises is propagated
        (``StopIteration`` at the end of an epoch under the DataIter protocol).
        """
        batch = self.data_iter.next()
        label = batch.label[0]
        # One label array per provide_label entry, so the batch never carries
        # more (or fewer) labels than the iterator advertises.
        labels = [label] * len(self.provide_label)
        return mx.io.DataBatch(data=batch.data, label=labels,
                               pad=batch.pad, index=batch.index)
import numpy as np
# Eigenvalues / eigenvectors of the RGB pixel covariance used by LightingAug
# for PCA-based colour noise. NOTE(review): these look like the values
# commonly quoted for ImageNet (0-255 pixel scale); confirm they suit this
# dataset.
eigval = np.array([55.46, 4.794, 1.148])
eigvec = np.array([[-0.5675, 0.7192, 0.4009],
                   [-0.5808, -0.0045, -0.8140],
                   [-0.5836, -0.6948, 0.4203]])

# Side length in pixels of the square network input.
shape_ = 112
# data_shape for ImageIter: (channels, height, width).
shape = (3, shape_, shape_)

# Extra pixels added to the resize target before cropping back to shape_.
# With the default of 0, every image is force-resized straight to
# (shape_, shape_). The crops below are then no-ops: a random crop covering
# the whole image has only one possible position. Set e.g. crop_pad = 32 to
# make RandomCropAug jitter the training crop. The test pipeline center-crops
# an image of the same size, so train and test see objects at the same scale.
crop_pad = 0
_resize = (shape_ + crop_pad, shape_ + crop_pad)

aug_list_test = [
    mx.image.ForceResizeAug(size=_resize),
    mx.image.CenterCropAug((shape_, shape_)),
]

aug_list_train = [
    mx.image.ForceResizeAug(size=_resize),
    mx.image.RandomCropAug((shape_, shape_)),
    mx.image.HorizontalFlipAug(0.5),  # flip with probability 0.5
    # Cast decoded (uint8) images to float32 before the colour augmenters.
    mx.image.CastAug(),
    mx.image.ColorJitterAug(0.0, 0.1, 0.1),  # brightness, contrast, saturation
    mx.image.HueJitterAug(0.5),
    mx.image.LightingAug(0.1, eigval, eigvec),  # PCA noise, alphastd=0.1
]
def get_iterator(batch_size, train_list='/you/path/train.lst',
                 val_list='/you/path/val.lst', path_root=''):
    """Return ``(train_iter, val_iter)`` for training, each wrapped in custom_iter.

    Args:
        batch_size: number of images per batch.
        train_list: path to the training ``.lst`` file, as produced by
            ``im2rec.py --list``.
        val_list: path to the validation ``.lst`` file.
        path_root: directory prepended to every relative image path in the
            list files. ``''`` resolves them against the current working
            directory.

    Returns:
        A tuple ``(train, val)``. ``train`` is shuffled and uses
        ``aug_list_train``; ``val`` is not shuffled and uses ``aug_list_test``.
    """
    train_iter = mx.image.ImageIter(batch_size=batch_size,
                                    data_shape=shape,
                                    label_width=1,
                                    aug_list=aug_list_train,
                                    shuffle=True,
                                    path_root=path_root,
                                    path_imglist=train_list)
    # Built once: the original constructed this iterator twice and discarded
    # the first, re-reading the whole list file for nothing.
    val_iter = mx.image.ImageIter(batch_size=batch_size,
                                  data_shape=shape,
                                  label_width=1,
                                  shuffle=False,
                                  aug_list=aug_list_test,
                                  path_root=path_root,
                                  path_imglist=val_list)
    return (custom_iter(train_iter), custom_iter(val_iter))
# TODO: set to the batch size used for training (placeholder value).
batch_size = 32
# train_dataiter was previously used here without ever being defined (NameError).
train_dataiter, val_dataiter = get_iterator(batch_size)
# Wrap the training iterator in a multi-threaded prefetching iterator.
train_dataiter = mx.io.PrefetchingIter(train_dataiter)