MXNet: Data Loading and Augmentation

Original link: https://blog.csdn.net/weixin_39451323/article/details/93659168
Original article


Data Loading and Augmentation

  • 1. Reading raw image data directly
    • 1.1 Generating the .lst file
    • 1.2 Basic data reading: read_lst.py
  • 2. Reading data from RecordIO files
    • 2.1 Generating the RecordIO files
    • 2.2 Reading RecordIO data: read_rec.py
  • 3. Data augmentation
    • 3.1 Resize
    • 3.2 Crop
    • 3.3 Mirror
    • 3.4 Brightness
    • 3.5 Contrast
    • 3.6 Saturation

1. Reading raw image data directly

1.1 Generating the .lst file

The im2rec.py script is used to generate the .lst file.
Directory layout:
[Image 1]
[Image 2]

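Run the commands below to generate the .lst files.

# The first data/train is the prefix; running the command produces train.lst
# The second data/train is the data root; set it according to where the data is stored
# --list tells the script to generate a .lst file; the same script can also generate RecordIO files
# --recursive searches the given directory recursively; it is required when, for example, train contains class folders that in turn contain the images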
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python /tools/im2rec.py data/train data/train --list --recursive
python: can't open file '/tools/im2rec.py': [Errno 2] No such file or directory
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py data/train data/train --list --recursive
cock 0
ostrich 1
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py data/val data/val --list --recursive
cock 0
ostrich 1

Result:
[Image 3]
[Image 4]
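
Each line of a .lst file written by im2rec.py holds tab-separated fields: an integer index, one or more label values, and the image path relative to the root. The following is a minimal sketch for inspecting such a file with the standard library only. The helper name read_lst and the two sample rows are illustrative assumptions, not part of im2rec.py or of the dataset above.

```python
import os
import tempfile


def read_lst(path):
    """Parse a .lst file into (index, labels, relative_path) tuples."""
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            fields = line.rstrip("\n").split("\t")
            if len(fields) < 3:
                continue  # skip blank or malformed lines
            index = int(fields[0])
            labels = [float(x) for x in fields[1:-1]]
            records.append((index, labels, fields[-1]))
    return records


# Hypothetical sample content, written to a temporary file for the demo.
sample = "0\t0.000000\tcock/img_0001.jpg\n1\t1.000000\tostrich/img_0002.jpg\n"
with tempfile.NamedTemporaryFile("w", suffix=".lst", delete=False,
                                 encoding="utf-8") as tmp:
    tmp.write(sample)
records = read_lst(tmp.name)
os.remove(tmp.name)
print(records)
```

Pointing read_lst at data/train.lst is a quick way to check that every class folder was picked up before building an iterator on top of the list.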

1.2 Basic data reading: read_lst.py

import mxnet as mx
import matplotlib.pyplot as plt

# Read the images directly from disk, driven by the .lst files
train_data = mx.image.ImageIter(batch_size=32,
                                data_shape=(3, 224, 224),
                                path_imglist='data/train.lst',
                                path_root='data/train',
                                shuffle=True)
val_data = mx.image.ImageIter(batch_size=32,
                              data_shape=(3, 224, 224),
                              path_imglist='data/val.lst',
                              path_root='data/val')
train_data.reset()
print(train_data)
data_batch = train_data.next()
print(data_batch)
data = data_batch.data[0]
#print(data)
plt.figure()
for i in range(4):
    # Convert (C, H, W) to (H, W, C) so that matplotlib can display the image
    save_image = data[i].astype('uint8').asnumpy().transpose((1, 2, 0))
    plt.subplot(1, 4, i + 1)
    plt.imshow(save_image)
plt.savefig('image_sample.jpg')

# The same iterator with resize and random mirroring added
train_data = mx.image.ImageIter(batch_size=32,
                                data_shape=(3, 224, 224),
                                path_imglist='data/train.lst',
                                path_root='data/train',
                                shuffle=True,
                                resize=256,
                                rand_mirror=True)

Output:

<mxnet.image.image.ImageIter object at 0x7fc7ac5839e8>
DataBatch: data shapes: [(32, 3, 224, 224)] label shapes: [(32,)]

[Image 5]
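
The batch shape (32, 3, 224, 224) above is channel-first, while matplotlib's imshow expects height, width, channel order, which is why read_lst.py calls transpose((1, 2, 0)) on each image. The sketch below shows the same reordering on a tiny nested list in plain Python. The function name chw_to_hwc and the toy values are assumptions made for illustration.

```python
def chw_to_hwc(chw):
    """Rearrange a nested list from (C, H, W) to (H, W, C) order."""
    channels, height, width = len(chw), len(chw[0]), len(chw[0][0])
    return [[[chw[c][y][x] for c in range(channels)]
             for x in range(width)]
            for y in range(height)]


# A tiny 3-channel, 1x2 "image": the R, G and B planes are stored separately.
chw = [[[10, 11]], [[20, 21]], [[30, 31]]]
hwc = chw_to_hwc(chw)
print(hwc)  # [[[10, 20, 30], [11, 21, 31]]]
```

After the reordering, each innermost list is one pixel with its three channel values side by side.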

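2. Reading data from RecordIO files

2.1 Generating the RecordIO files

# --num-thread: sets the number of threads
# data/train.lst: path to the .lst file
# data/train: directory that holds the original images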
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py --num-thread 16 data/train.lst data/train
Creating .rec file from /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data/train.lst in /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data
time: 0.009440183639526367  count: 0
time: 0.8572912216186523  count: 1000
time: 0.8444476127624512  count: 2000
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py --num-thread 16 data/val.lst data/val
Creating .rec file from /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data/val.lst in /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data
time: 0.014867544174194336  count: 0

   
   
   
   
Result:
[Image 6]
[Image 7]

2.2 Reading RecordIO data: read_rec.py

The mx.io.ImageRecordIter() interface reads RecordIO data files. It accepts several dozen parameters; the code below is a brief example:

import mxnet as mx
import matplotlib.pyplot as plt
# Adds resize and random mirroring
train_data = mx.io.ImageRecordIter(batch_size=32,
                                   data_shape=(3,224,224),
                                   path_imgrec='data/train.rec',
                                   path_imgidx='data/train.idx',
                                   shuffle=True,
                                   resize=256,
                                   rand_mirror=True)

val_data = mx.io.ImageRecordIter(batch_size=32,
                                 data_shape=(3,224,224),
                                 path_imgrec='data/val.rec',
                                 path_imgidx='data/val.idx',
                                 resize=256)

train_data.reset()
print(train_data)
data_batch = train_data.next()
print(data_batch)
data = data_batch.data[0]
plt.figure()
for i in range(4):
    # Convert (C, H, W) to (H, W, C) so that matplotlib can display the image
    save_image = data[i].astype('uint8').asnumpy().transpose((1,2,0))
    plt.subplot(1,4,i+1)
    plt.imshow(save_image)
plt.savefig('image_sample_rec.jpg')


Output:

<mxnet.io.io.MXDataIter object at 0x7fc7ac62fa58>
DataBatch: data shapes: [(32, 3, 224, 224)] label shapes: [(32,)]

[Image 8]

3. Data augmentation

3.1 Resize

  • mx.image.ResizeAug(size=)
  • mx.image.ForceResizeAug(size=(224,224))

import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    prefix = 'data-augmentation/resize/'
    image = 'ILSVRC2012_val_00000002.jpg'
    image_name = image.split(".")[0]  # yields ILSVRC2012_val_00000002
    image_string = open('data-augmentation/resize/{}'.format(image), 'rb').read()  # read the image in binary mode
    data = mx.image.imdecode(image_string, flag=1)
    print("Shape of data:{}".format(data.shape))
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(prefix + image_name))

    # Resize the shorter side to size and shrink the other side proportionally
    shorterResize = mx.image.ResizeAug(size=224, interp=2)  # interp selects the interpolation method; 2 is the default and may be omitted
    shorterResize_data = shorterResize(data)
    print("Shape of data:{}".format(shorterResize_data.shape))
    plt.imshow(shorterResize_data.asnumpy())
    plt.savefig('{}_shorterResize.png'.format(prefix + image_name))

    # Resize the shorter side to size and enlarge the other side proportionally
    shorterResize = mx.image.ResizeAug(size=1000)
    shorterResize_data = shorterResize(data)
    print("Shape of data:{}".format(shorterResize_data.shape))
    plt.imshow(shorterResize_data.asnumpy())
    plt.savefig('{}_shorterResize_bigsize.png'.format(prefix + image_name))

    # Force the image to (224, 224) by interpolation; the subject gets distorted
    forceResize = mx.image.ForceResizeAug(size=(224,224))
    forceResize_data = forceResize(data)
    print("Shape of data:{}".format(forceResize_data.shape))
    plt.imshow(forceResize_data.asnumpy())
    plt.savefig('{}_forceResize.png'.format(prefix + image_name))

    # Force the image to (200, 300); size is given as (width, height)
    forceResize = mx.image.ForceResizeAug(size=(200, 300))
    forceResize_data = forceResize(data)
    print("Shape of data:{}".format(forceResize_data.shape))
    plt.imshow(forceResize_data.asnumpy())
    plt.savefig('{}_forceResize_diff.png'.format(prefix + image_name))

Output:

Shape of data:(1440, 1080, 3)
Shape of data:(298, 224, 3)
Shape of data:(1333, 1000, 3)
Shape of data:(224, 224, 3)
Shape of data:(300, 200, 3)

[Image 9]
[Image 10]
[Image 11]
[Image 12]
[Image 13]
[Image 14]
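
The shapes printed for the shorter-side resize can be reproduced with integer arithmetic: the shorter side becomes size and the longer side is scaled by the same factor, rounded down. The sketch below assumes that rounding rule (it matches the 298 and 1333 seen in the output). The function name resize_short_shape is hypothetical and not an MXNet API.

```python
def resize_short_shape(height, width, size):
    """Return (new_height, new_width) after scaling the shorter side to `size`."""
    if height > width:
        return size * height // width, size
    return size, size * width // height


# The original image in this section is 1440 high and 1080 wide.
for size in (224, 1000):
    print(size, resize_short_shape(1440, 1080, size))
```

For the forced resize no such calculation is needed: the output is exactly the requested (width, height), which is why size=(200, 300) yields an array of shape (300, 200, 3).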

3.2 Crop

  • center crop
  • random crop
  • random resize crop

import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    prefix = 'data-augmentation/crop/'
    image = 'ILSVRC2012_val_00000009.jpg'
    image_name = image.split(".")[0]
    image_string = open('data-augmentation/crop/{}'.format(image), 'rb').read()
    data = mx.image.imdecode(image_string, flag=1)
    print("Shape of data:{}".format(data.shape))
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(prefix + image_name))

    # Crop from the middle of the image, centered on the image center
    centerCrop = mx.image.CenterCropAug(size=(224,224))
    class_centerCrop_data = centerCrop(data)
    print("Shape of data:{}".format(class_centerCrop_data.shape))
    plt.imshow(class_centerCrop_data.asnumpy())
    plt.savefig('{}_centerCrop.png'.format(prefix + image_name))

    # Random crop: the crop window is placed at a random position in the image
    randomCrop = mx.image.RandomCropAug(size=(224,224))
    class_randomCrop_data = randomCrop(data)
    print("Shape of data:{}".format(class_randomCrop_data.shape))
    plt.imshow(class_randomCrop_data.asnumpy())
    plt.savefig('{}_randomCrop.png'.format(prefix + image_name))

    # size: output image size
    # area: the initial crop area is a fraction of the original image area
    #       (a single float acts as the lower bound of that fraction)
    # ratio: range of width/height ratios; once the crop area is chosen,
    #        the crop width and height follow from this ratio
    randomSizeCrop = mx.image.RandomSizedCropAug(size=(224,224), area=0.08,
                                                 ratio=(3/4, 4/3))
    class_randomSizedCrop_data = randomSizeCrop(data)
    print("Shape of data:{}".format(class_randomSizedCrop_data.shape))
    plt.imshow(class_randomSizedCrop_data.asnumpy())
    plt.savefig('{}_randomSizedCrop.png'.format(prefix + image_name))


Output:

Shape of data:(1440, 1080, 3)
Shape of data:(224, 224, 3)
Shape of data:(224, 224, 3)
Shape of data:(224, 224, 3)

[Image 15]
[Image 16]
[Image 17]
[Image 18]
[Image 19]
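
The three crop variants differ only in how the crop window is chosen. The sketch below computes the window (x0, y0, width, height) for each in plain Python. It is an illustration under stated assumptions, not MXNet's implementation: the function names are hypothetical, the aspect ratio is sampled uniformly, and the fallback to the full image after a few failed attempts is a choice made here. In the random sized variant the window would afterwards be resized to the requested output size.

```python
import math
import random


def center_crop_box(height, width, size):
    """Return (x0, y0, w, h) of a crop centered in the image; size is (w, h)."""
    new_w, new_h = size
    return (width - new_w) // 2, (height - new_h) // 2, new_w, new_h


def random_crop_box(height, width, size, rng=random):
    """Return (x0, y0, w, h) with the top-left corner drawn uniformly."""
    new_w, new_h = size
    x0 = rng.randint(0, width - new_w)
    y0 = rng.randint(0, height - new_h)
    return x0, y0, new_w, new_h


def random_sized_crop_box(height, width, area, ratio, rng=random, attempts=10):
    """Sample a window from an area fraction and an aspect-ratio range."""
    for _ in range(attempts):
        # area is treated as the minimum fraction of the source area.
        target_area = rng.uniform(area, 1.0) * height * width
        aspect = rng.uniform(*ratio)  # width / height
        new_w = int(round(math.sqrt(target_area * aspect)))
        new_h = int(round(math.sqrt(target_area / aspect)))
        if 0 < new_w <= width and 0 < new_h <= height:
            return random_crop_box(height, width, (new_w, new_h), rng)
    return 0, 0, width, height  # fallback assumed for this sketch


rng = random.Random(0)
print(center_crop_box(1440, 1080, (224, 224)))  # (428, 608, 224, 224)
print(random_crop_box(1440, 1080, (224, 224), rng))
print(random_sized_crop_box(1440, 1080, 0.08, (3 / 4, 4 / 3), rng))
```

All three results stay inside the 1080 by 1440 source image, which is why every crop in the output above ends up with the same final shape.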

3.3 Mirror

  • mx.image.HorizontalFlipAug(p=0.5)

import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000014.JPEG'
    image_name = image.split(".")[0]
    image_string = open('…/image/{}'.format(image), 'rb').read()
    data = mx.image.imdecode(image_string, flag=1)
    print("Shape of data:{}".format(data.shape))
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # p is the probability that the random mirror operation is applied
    mirror = mx.image.HorizontalFlipAug(p=0.5)
    mirror_data = mirror(data)
    plt.imshow(mirror_data.asnumpy())
    plt.savefig('{}_mirror.png'.format(image_name))


[Image 20]
[Image 21]
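
Because p is a probability, roughly half of the calls with p=0.5 return the image unchanged. The following minimal sketch shows the idea on a nested list in (H, W, C) order. The function name horizontal_flip is an assumption for illustration.

```python
import random


def horizontal_flip(image, p=0.5, rng=random):
    """Mirror an (H, W, C) nested-list image left to right with probability p."""
    if rng.random() < p:
        return [row[::-1] for row in image]
    return image


image = [[[1, 1, 1], [2, 2, 2], [3, 3, 3]]]  # one row, three pixels
flipped = horizontal_flip(image, p=1.0)    # always flipped
unchanged = horizontal_flip(image, p=0.0)  # never flipped
print(flipped)
print(unchanged)
```

Setting p=1.0 while debugging makes the effect visible on every run; the saved mirror image from the script above may otherwise look identical to the original.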

3.4 Brightness

  • mx.image.BrightnessJitterAug(brightness=0.3)
    Do not set the brightness too high, or the image will look distorted.
    Output image: the input pixel values multiplied by a random number in [1-brightness, 1+brightness].

import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000008.JPEG'
    image_name = image.split(".")[0]
    image_string = open('…/image/{}'.format(image), 'rb').read()
    data = mx.image.imdecode(image_string, flag=1)
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # Cast to float before applying the jitter, then back to uint8 for display
    cast = mx.image.CastAug()
    data = cast(data)
    brightness = mx.image.BrightnessJitterAug(brightness=0.3)
    brightness_data = brightness(data)
    brightness_data = mx.nd.Cast(brightness_data, dtype='uint8')
    plt.imshow(brightness_data.asnumpy())
    plt.savefig('{}_brightness.png'.format(image_name))


Output:
[Image 22]
[Image 23]
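
The rule stated for brightness jitter (multiply every pixel by one random factor between 1-brightness and 1+brightness) can be sketched in a few lines of plain Python. The function name brightness_jitter is hypothetical, and clipping the result to [0, 255] is a choice made for this sketch rather than something taken from the article.

```python
import random


def brightness_jitter(pixels, brightness, rng=random):
    """Scale every value by one random factor in [1 - brightness, 1 + brightness]."""
    alpha = 1.0 + rng.uniform(-brightness, brightness)
    # Clipping to [0, 255] is an assumption of this sketch.
    return [min(255.0, max(0.0, v * alpha)) for v in pixels], alpha


rng = random.Random(0)
pixels = [0.0, 64.0, 128.0, 250.0]  # a flattened toy image
jittered, alpha = brightness_jitter(pixels, 0.3, rng)
print(alpha, jittered)
```

With brightness=0.3 the factor stays between 0.7 and 1.3, so a value such as 250 can already reach the upper limit; larger settings push more pixels to the extremes, which is the distortion the note above warns about.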

3.5 Contrast

  • mx.image.ContrastJitterAug(contrast=0.3)
    Do not set the contrast too high, or the image will look distorted.

import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000008.JPEG'
    image_name = image.split(".")[0]
    image_string = open('…/image/{}'.format(image), 'rb').read()
    data = mx.image.imdecode(image_string, flag=1)
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # Cast to float before applying the jitter, then back to uint8 for display
    cast = mx.image.CastAug()
    data = cast(data)
    contrast = mx.image.ContrastJitterAug(contrast=0.3)
    contrast_data = contrast(data)
    contrast_data = mx.nd.Cast(contrast_data, dtype='uint8')
    plt.imshow(contrast_data.asnumpy())
    plt.savefig('{}_contrast.png'.format(image_name))


Output:
[Image 24]
[Image 25]
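
The article does not spell out the contrast formula. A minimal sketch follows, assuming the common formulation in which each value is blended with the image's mean gray level: out = alpha * value + (1 - alpha) * mean_gray, with alpha drawn from [1-contrast, 1+contrast]. The function names are hypothetical, and the gray weights are the standard BT.601 luminance coefficients.

```python
import random

LUMA = (0.299, 0.587, 0.114)  # BT.601 weights for R, G, B


def adjust_contrast(pixels, alpha):
    """Blend each (r, g, b) pixel with the mean gray level of the image."""
    grays = [sum(c * w for c, w in zip(p, LUMA)) for p in pixels]
    mean_gray = sum(grays) / len(grays)
    return [tuple(alpha * c + (1.0 - alpha) * mean_gray for c in p)
            for p in pixels]


def contrast_jitter(pixels, contrast, rng=random):
    """Apply adjust_contrast with alpha sampled from [1 - contrast, 1 + contrast]."""
    alpha = 1.0 + rng.uniform(-contrast, contrast)
    return adjust_contrast(pixels, alpha), alpha


pixels = [(10.0, 20.0, 30.0), (200.0, 180.0, 160.0)]
print(adjust_contrast(pixels, 1.0))  # alpha = 1 leaves the image unchanged
print(contrast_jitter(pixels, 0.3, random.Random(0)))
```

An alpha below 1 pulls every pixel toward the mean gray (flatter image), while an alpha above 1 pushes values away from it.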

3.6 Saturation

  • saturation = mx.image.SaturationJitterAug(saturation=0.3)
    Saturation refers to color purity: the higher the purity, the more vivid the colors; the lower the purity, the duller they appear.

import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000008.JPEG'
    image_name = image.split(".")[0]
    image_string = open('…/image/{}'.format(image), 'rb').read()
    data = mx.image.imdecode(image_string, flag=1)
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # Cast to float before applying the jitter, then back to uint8 for display
    cast = mx.image.CastAug()
    data = cast(data)
    saturation = mx.image.SaturationJitterAug(saturation=0.3)
    saturation_data = saturation(data)
    saturation_data = mx.nd.Cast(saturation_data, dtype='uint8')
    plt.imshow(saturation_data.asnumpy())
    plt.savefig('{}_saturation.png'.format(image_name))


Output:
[Image 26]
[Image 27]
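
Saturation as color purity can be made concrete by blending each pixel with its own gray value: a blend factor of 0 gives a fully gray pixel, 1 leaves it unchanged, and values above 1 make colors more vivid. The sketch below assumes that formulation, with the factor drawn from [1-saturation, 1+saturation]. The function names are hypothetical and the weights are the standard BT.601 luminance coefficients.

```python
import random

LUMA = (0.299, 0.587, 0.114)  # BT.601 weights for R, G, B


def adjust_saturation(pixels, alpha):
    """Blend each (r, g, b) pixel with its own gray value."""
    out = []
    for p in pixels:
        gray = sum(c * w for c, w in zip(p, LUMA))
        out.append(tuple(alpha * c + (1.0 - alpha) * gray for c in p))
    return out


def saturation_jitter(pixels, saturation, rng=random):
    """Apply adjust_saturation with alpha sampled from [1 - saturation, 1 + saturation]."""
    alpha = 1.0 + rng.uniform(-saturation, saturation)
    return adjust_saturation(pixels, alpha), alpha


pure_red = [(255.0, 0.0, 0.0)]
print(adjust_saturation(pure_red, 0.0))  # fully desaturated: three equal values
print(saturation_jitter(pure_red, 0.3, random.Random(0)))
```

Unlike the contrast sketch, the gray value here is computed per pixel, so neutral gray pixels are left untouched whatever the factor.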

                                
Last published: 2019-06-26 10:07:52
