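MXNet: Data Reading and Augmentation

  • 1. Reading raw image data directly
    • 1.1 Generating the .lst file
    • 1.2 Basic data reading: read_lst.py
  • 2. Reading data from RecordIO files
    • 2.1 Generating RecordIO files
    • 2.2 Reading RecordIO data: read_rec.py
  • 3. Data augmentation
    • 3.1 Resizing (resize)
    • 3.2 Cropping (crop)
    • 3.3 Mirroring (mirror)
    • 3.4 Brightness (brightness)
    • 3.5 Contrast (contrast)
    • 3.6 Saturation (saturation)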

1. Reading raw image data directly

1.1 Generating the .lst file

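The im2rec.py script generates the .lst files.
Directory layout (demo5/): the script is tools/im2rec.py. data/train and data/val each contain one sub-folder per class (cock and ostrich), and those folders hold the images:

demo5/
  tools/
    im2rec.py
  data/
    train/
      cock/
      ostrich/
    val/
      cock/
      ostrich/

Run the following commands to generate the .lst files.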

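#The first data/train is the prefix: the command writes data/train.lst
#The second data/train is the root directory of the images; set it to wherever your data is stored
#--list generates a .lst file (without it, im2rec.py packs the images into a RecordIO file instead, see 2.1)
#--recursive searches the given directory recursively; it is required when, as here, train/ contains class folders that in turn contain the images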
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python /tools/im2rec.py data/train data/train --list --recursive
python: can't open file '/tools/im2rec.py': [Errno 2] No such file or directory
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py data/train data/train --list --recursive
cock 0
ostrich 1
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py data/val data/val --list --recursive
cock 0
ostrich 1

Result: the two commands create data/train.lst and data/val.lst.
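For reference, each line of a .lst file written by im2rec.py is tab-separated. It holds an integer image index, the label (written as a float), and the image path relative to the root directory. The sketch below is a simplified, pure-Python imitation of `--list --recursive` for the one-level class-folder layout used here. `write_lst` is a hypothetical helper, not part of MXNet. It skips im2rec.py's other options, such as train/test splitting and deeper directory trees.

```python
import os
import random


def write_lst(prefix, root, exts=(".jpg", ".jpeg", ".png"), shuffle=True, seed=0):
    """Hypothetical helper: simplified imitation of `im2rec.py prefix root --list --recursive`.

    Assumes every sub-folder of `root` is one class; labels follow sorted folder names.
    Writes `prefix + ".lst"` with lines: index<TAB>label<TAB>relative path.
    """
    classes = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
    items = []
    for label, cls in enumerate(classes):
        print(cls, label)  # the same "class label" mapping im2rec.py prints
        for fname in sorted(os.listdir(os.path.join(root, cls))):
            if os.path.splitext(fname)[1].lower() in exts:
                # (image index, relative path, label); the index stays attached to its image
                items.append((len(items), os.path.join(cls, fname), label))
    if shuffle:
        random.Random(seed).shuffle(items)  # fixed seed so the list is reproducible
    with open(prefix + ".lst", "w") as f:
        for idx, rel_path, label in items:
            f.write("%d\t%f\t%s\n" % (idx, label, rel_path))
    return items
```

With the layout above, `write_lst('data/train', 'data/train')` would print the same `cock 0` / `ostrich 1` mapping shown in the transcript.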

1.2 Basic data reading: read_lst.py

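import mxnet as mx
import matplotlib.pyplot as plt

# Training iterator: reads the images listed in data/train.lst from data/train
train_data = mx.image.ImageIter(batch_size=32,
                                data_shape=(3, 224, 224),  # (channels, height, width) of each output image
                                path_imglist='data/train.lst',
                                path_root='data/train',
                                shuffle=True)  # shuffle the sample order
# Validation iterator: no shuffling
val_data = mx.image.ImageIter(batch_size=32,
                              data_shape=(3, 224, 224),
                              path_imglist='data/val.lst',
                              path_root='data/val')
train_data.reset()              # start a new pass over the data
print(train_data)
data_batch = train_data.next()  # fetch one batch
print(data_batch)
data = data_batch.data[0]       # NDArray of shape (32, 3, 224, 224)
#print(data)
plt.figure()
for i in range(4):
    # CHW float -> HWC uint8 so that matplotlib can display it
    save_image = data[i].astype('uint8').asnumpy().transpose((1, 2, 0))
    plt.subplot(1, 4, i + 1)
    plt.imshow(save_image)
plt.savefig('image_sample.jpg')

# ImageIter also accepts augmentation arguments, e.g. resize the shorter side
# to 256 pixels before cropping, and randomly mirror images horizontally:
train_data = mx.image.ImageIter(batch_size=32,
                                data_shape=(3, 224, 224),
                                path_imglist='data/train.lst',
                                path_root='data/train',
                                shuffle=True,
                                resize=256,
                                rand_mirror=True)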

Output:

<mxnet.image.image.ImageIter object at 0x7fc7ac5839e8>
DataBatch: data shapes: [(32, 3, 224, 224)] label shapes: [(32,)]

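The display loop above converts each sample before plotting. The iterator returns batches in NCHW layout (here (32, 3, 224, 224), as printed), but plt.imshow expects an HWC image, ideally uint8. Below is a NumPy-only sketch of that conversion. A small random array stands in for `data_batch.data[0].asnumpy()`, and `to_displayable` is a hypothetical name. The clip is an extra safety step that the original loop does not need when pixel values already lie in [0, 255].

```python
import numpy as np


def to_displayable(batch, i):
    """Turn sample i of an NCHW float batch into an HWC uint8 array for plt.imshow."""
    chw = np.clip(batch[i], 0, 255).astype(np.uint8)  # imshow expects uint8 values in [0, 255]
    return chw.transpose((1, 2, 0))                    # (C, H, W) -> (H, W, C)


# A small fake batch standing in for data_batch.data[0].asnumpy()
batch = np.random.uniform(0, 255, size=(4, 3, 224, 224)).astype(np.float32)
img = to_displayable(batch, 0)
print(img.shape, img.dtype)
```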

2. Reading data from RecordIO files

2.1 Generating RecordIO files

#--num-thread: number of threads used to pack the images
#data/train.lst: path of the .lst file
#data/train: directory that holds the original images
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py --num-thread 16 data/train.lst data/train
Creating .rec file from /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data/train.lst in /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data
time: 0.009440183639526367  count: 0
time: 0.8572912216186523  count: 1000
time: 0.8444476127624512  count: 2000
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo5$ python ./tools/im2rec.py --num-thread 16 data/val.lst data/val
Creating .rec file from /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data/val.lst in /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo5/data
time: 0.014867544174194336  count: 0

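Result: data/train.rec and data/train.idx (and likewise data/val.rec and data/val.idx) are created in data/. The .idx files are the index files passed as path_imgidx in 2.2.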
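The sketch below writes and reads a tiny in-memory .rec/.idx pair in pure Python to illustrate what these files contain. It follows the layout as I understand it from dmlc-core and mxnet/recordio.py:

- Each record starts with the magic word 0xced7230a and a length word. The upper 3 bits of the length word are a continuation flag.
- The payload is zero-padded to a 4-byte boundary.
- An image payload starts with a 24-byte header (flag, label, id, id2) packed as "IfQQ", followed by the encoded image bytes.
- The .idx file is plain text with one key<TAB>byte-offset line per record, which lets a reader jump straight to any record.

`frame`, `write_rec` and `read_record` are hypothetical helpers for illustration only. They ignore multi-part records, so use im2rec.py and mx.recordio for real data.

```python
import struct

MAGIC = 0xCED7230A         # RecordIO magic word (as I understand dmlc-core's format)
IR_FORMAT = "<IfQQ"        # image-record header: flag, label, id, id2 -> 24 bytes
LEN_MASK = (1 << 29) - 1   # lower 29 bits of the length word hold the payload length


def frame(payload):
    """Wrap one payload as a single-part record, zero-padded to 4 bytes.
    Simplification: real writers split payloads that contain the magic word."""
    pad = -len(payload) % 4
    return struct.pack("<II", MAGIC, len(payload)) + payload + b"\x00" * pad


def write_rec(images):
    """images: list of (label, encoded_image_bytes). Returns (rec_bytes, idx_text)."""
    rec, idx_lines = bytearray(), []
    for i, (label, img_bytes) in enumerate(images):
        header = struct.pack(IR_FORMAT, 0, float(label), i, 0)  # flag 0: scalar label
        idx_lines.append("%d\t%d\n" % (i, len(rec)))           # .idx line: key<TAB>offset
        rec += frame(header + img_bytes)
    return bytes(rec), "".join(idx_lines)


def read_record(rec, offset):
    """Read the record starting at `offset`; return (label, id, image_bytes)."""
    magic, lrec = struct.unpack_from("<II", rec, offset)
    assert magic == MAGIC, "offset is not at a record boundary"
    length = lrec & LEN_MASK
    payload = rec[offset + 8: offset + 8 + length]
    _flag, label, rid, _id2 = struct.unpack_from(IR_FORMAT, payload)
    return label, rid, payload[struct.calcsize(IR_FORMAT):]


rec, idx = write_rec([(0, b"fake-jpeg-1"), (1, b"fake-jpeg-2")])
print(idx)
```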

2.2 Reading RecordIO data: read_rec.py

The mx.io.ImageRecordIter() interface reads RecordIO data files. It takes dozens of parameters; the code below uses only a few common ones:

import mxnet as mx
import matplotlib.pyplot as plt

# resize=256 resizes the shorter side to 256 pixels;
# rand_mirror=True randomly flips images horizontally
train_data = mx.io.ImageRecordIter(batch_size=32,
                                   data_shape=(3, 224, 224),
                                   path_imgrec='data/train.rec',  # packed images
                                   path_imgidx='data/train.idx',  # index into the .rec file
                                   shuffle=True,
                                   resize=256,
                                   rand_mirror=True)

val_data = mx.io.ImageRecordIter(batch_size=32,
                                 data_shape=(3, 224, 224),
                                 path_imgrec='data/val.rec',
                                 path_imgidx='data/val.idx',
                                 resize=256)

train_data.reset()
print(train_data)
data_batch = train_data.next()
print(data_batch)
data = data_batch.data[0]
plt.figure()
for i in range(4):
    # CHW float -> HWC uint8 for display
    save_image = data[i].astype('uint8').asnumpy().transpose((1, 2, 0))
    plt.subplot(1, 4, i + 1)
    plt.imshow(save_image)
plt.savefig('image_sample_rec.jpg')

Output:

<mxnet.io.io.MXDataIter object at 0x7fc7ac62fa58>
DataBatch: data shapes: [(32, 3, 224, 224)] label shapes: [(32,)]


3. Data augmentation

3.1 Resizing (resize)

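  • mx.image.ResizeAug(size): resizes the shorter side to size and keeps the aspect ratio
  • mx.image.ForceResizeAug(size=(224,224)): resizes to exactly (width, height), ignoring the aspect ratio
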
import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    prefix = 'data-augmentation/resize/'
    image = 'ILSVRC2012_val_00000002.jpg'
    image_name = image.split(".")[0]  # -> 'ILSVRC2012_val_00000002'
    with open(prefix + image, 'rb') as f:  # read the image file as raw bytes
        image_string = f.read()
    data = mx.image.imdecode(image_string, flag=1)  # flag=1: decode as a 3-channel color image (HWC)
    print("Shape of data:{}".format(data.shape))
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(prefix + image_name))

    # Resize the shorter side to size and scale the longer side proportionally.
    # interp selects the interpolation method; 2 (bicubic) is the default, so it can be omitted
    shorterResize = mx.image.ResizeAug(size=224, interp=2)
    shorterResize_data = shorterResize(data)
    print("Shape of data:{}".format(shorterResize_data.shape))
    plt.imshow(shorterResize_data.asnumpy())
    plt.savefig('{}_shorterResize.png'.format(prefix + image_name))

    # Same operation with a larger size: the longer side is enlarged proportionally
    shorterResize = mx.image.ResizeAug(size=1000)
    shorterResize_data = shorterResize(data)
    print("Shape of data:{}".format(shorterResize_data.shape))
    plt.imshow(shorterResize_data.asnumpy())
    plt.savefig('{}_shorterResize_bigsize.png'.format(prefix + image_name))

    # Force the output to 224x224 by interpolation; the aspect ratio is not kept,
    # so the person in the image looks distorted
    forceResize = mx.image.ForceResizeAug(size=(224, 224))
    forceResize_data = forceResize(data)
    print("Shape of data:{}".format(forceResize_data.shape))
    plt.imshow(forceResize_data.asnumpy())
    plt.savefig('{}_forceResize.png'.format(prefix + image_name))

    # size is (width, height), so (200, 300) gives an array of shape (300, 200, 3)
    forceResize = mx.image.ForceResizeAug(size=(200, 300))
    forceResize_data = forceResize(data)
    print("Shape of data:{}".format(forceResize_data.shape))
    plt.imshow(forceResize_data.asnumpy())
    plt.savefig('{}_forceResize_diff.png'.format(prefix + image_name))

Output:

Shape of data:(1440, 1080, 3)
Shape of data:(298, 224, 3)
Shape of data:(1333, 1000, 3)
Shape of data:(224, 224, 3)
Shape of data:(300, 200, 3)

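The shapes printed above follow from simple integer arithmetic. With size=224, the 1440x1080 image keeps its shorter side at 224 and gets 224 * 1440 // 1080 = 298 rows. With size=1000 it gets 1000 * 1440 // 1080 = 1333 rows. The NumPy sketch below reproduces those numbers. It uses floor division, which matches the printed output and, as far as I can tell, the arithmetic of mx.image.resize_short. `resize_nearest` is a hypothetical nearest-neighbour stand-in, not MXNet's bicubic default. It also shows why ForceResizeAug(size=(200, 300)) yields a 300x200 array.

```python
import numpy as np


def shorter_side_shape(h, w, size):
    """Target (height, width) when the shorter side is resized to `size`."""
    if h > w:
        return size * h // w, size
    return size, size * w // h


def resize_nearest(img, new_h, new_w):
    """Nearest-neighbour resize of an HWC array (illustration only; MXNet defaults to bicubic)."""
    h, w = img.shape[:2]
    rows = np.arange(new_h) * h // new_h  # source row for each output row
    cols = np.arange(new_w) * w // new_w  # source column for each output column
    return img[rows[:, None], cols]


img = np.zeros((1440, 1080, 3), dtype=np.uint8)
print(shorter_side_shape(1440, 1080, 224))   # (298, 224)
print(shorter_side_shape(1440, 1080, 1000))  # (1333, 1000)
# ForceResizeAug takes (width, height), so size=(200, 300) means 300 rows x 200 columns
print(resize_nearest(img, 300, 200).shape)   # (300, 200, 3)
```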

3.2 Cropping (crop)

  • center crop: mx.image.CenterCropAug
  • random crop: mx.image.RandomCropAug
  • random resized crop: mx.image.RandomSizedCropAug

import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    prefix = 'data-augmentation/crop/'
    image = 'ILSVRC2012_val_00000009.jpg'
    image_name = image.split(".")[0]
    with open(prefix + image, 'rb') as f:
        image_string = f.read()
    data = mx.image.imdecode(image_string, flag=1)
    print("Shape of data:{}".format(data.shape))
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(prefix + image_name))

    # Center crop: the crop window is centered on the center of the image
    centerCrop = mx.image.CenterCropAug(size=(224, 224))
    class_centerCrop_data = centerCrop(data)
    print("Shape of data:{}".format(class_centerCrop_data.shape))
    plt.imshow(class_centerCrop_data.asnumpy())
    plt.savefig('{}_centerCrop.png'.format(prefix + image_name))

    # Random crop: the crop window is placed at a random position in the image
    randomCrop = mx.image.RandomCropAug(size=(224, 224))
    class_randomCrop_data = randomCrop(data)
    print("Shape of data:{}".format(class_randomCrop_data.shape))
    plt.imshow(class_randomCrop_data.asnumpy())
    plt.savefig('{}_randomCrop.png'.format(prefix + image_name))

    # size: output image size (width, height)
    # area: the crop covers a random fraction of the original image area, between area and 1.0
    # ratio: range of the crop's width/height ratio; together with the chosen area
    #        it determines the crop's width and height. The crop is then resized to size.
    randomSizeCrop = mx.image.RandomSizedCropAug(size=(224, 224), area=0.08,
                                                 ratio=(3/4, 4/3))
    class_randomSizedCrop_data = randomSizeCrop(data)
    print("Shape of data:{}".format(class_randomSizedCrop_data.shape))
    plt.imshow(class_randomSizedCrop_data.asnumpy())
    plt.savefig('{}_randomSizedCrop.png'.format(prefix + image_name))

Output:

Shape of data:(1440, 1080, 3)
Shape of data:(224, 224, 3)
Shape of data:(224, 224, 3)
Shape of data:(224, 224, 3)

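All three croppers come down to choosing a window (x0, y0, width, height), the same argument order as mx.image.fixed_crop, and resizing it where needed. The pure-Python sketch below computes such windows. The function names are hypothetical. The random sized crop samples the area fraction from [area, 1.0], as described above. It samples the aspect ratio log-uniformly, which is this sketch's choice so that r and 1/r are equally likely; MXNet's exact sampling may differ. Falling back to the largest centered square after 10 failed attempts is also an assumption of the sketch.

```python
import math
import random


def center_crop_box(img_w, img_h, crop_w, crop_h):
    """(x0, y0, width, height) of a window centered on the image center."""
    return (img_w - crop_w) // 2, (img_h - crop_h) // 2, crop_w, crop_h


def random_crop_box(img_w, img_h, crop_w, crop_h, rng=random):
    """Same window size, but the top-left corner is drawn uniformly at random."""
    return rng.randint(0, img_w - crop_w), rng.randint(0, img_h - crop_h), crop_w, crop_h


def random_sized_crop_box(img_w, img_h, area=(0.08, 1.0), ratio=(3 / 4, 4 / 3),
                          rng=random, attempts=10):
    """Window covering a random fraction of the image area with a random w/h ratio.
    The caller would then resize the crop to the output size (e.g. 224x224)."""
    for _ in range(attempts):
        target_area = rng.uniform(*area) * img_w * img_h
        r = math.exp(rng.uniform(math.log(ratio[0]), math.log(ratio[1])))
        crop_w = int(round(math.sqrt(target_area * r)))
        crop_h = int(round(math.sqrt(target_area / r)))
        if 0 < crop_w <= img_w and 0 < crop_h <= img_h:
            return random_crop_box(img_w, img_h, crop_w, crop_h, rng)
    side = min(img_w, img_h)  # fallback (assumption): largest centered square
    return center_crop_box(img_w, img_h, side, side)


print(center_crop_box(1080, 1440, 224, 224))  # (428, 608, 224, 224)
rng = random.Random(0)
print(random_sized_crop_box(1080, 1440, rng=rng))
```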

3.3 Mirroring (mirror)

  • mx.image.HorizontalFlipAug(p=0.5)

import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000014.JPEG'
    image_name = image.split(".")[0]
    with open('../image/{}'.format(image), 'rb') as f:
        image_string = f.read()
    data = mx.image.imdecode(image_string, flag=1)
    print("Shape of data:{}".format(data.shape))
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # p is the probability of flipping the image horizontally; with p=0.5 the saved
    # image is left unflipped about half the time (use p=1 to always flip for the demo)
    mirror = mx.image.HorizontalFlipAug(p=0.5)
    mirror_data = mirror(data)
    plt.imshow(mirror_data.asnumpy())
    plt.savefig('{}_mirror.png'.format(image_name))

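A horizontal flip with probability p amounts to a coin toss followed by reversing the width axis of the HWC array. The NumPy sketch below is a stand-in for what HorizontalFlipAug does. `horizontal_flip` is a hypothetical name.

```python
import random

import numpy as np


def horizontal_flip(img, p=0.5, rng=random):
    """Flip an HWC image left-right with probability p."""
    if rng.random() < p:
        return img[:, ::-1]  # reverse the width (column) axis
    return img


img = np.arange(6).reshape(1, 3, 2)    # 1 row, 3 columns, 2 channels
print(horizontal_flip(img, p=1.0)[0])  # the three columns in reverse order
```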

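3.4 Brightness (brightness)

  • mx.image.BrightnessJitterAug(brightness=0.3)
    Do not set brightness too high, or the image becomes distorted.
    Output image: the input pixel values multiplied by a random factor drawn from [1 - brightness, 1 + brightness].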
import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000008.JPEG'
    image_name = image.split(".")[0]
    with open('../image/{}'.format(image), 'rb') as f:
        image_string = f.read()
    data = mx.image.imdecode(image_string, flag=1)
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # The jitter augmenters do arithmetic on the pixels, so cast uint8 -> float32 first
    cast = mx.image.CastAug()
    data = cast(data)
    brightness = mx.image.BrightnessJitterAug(brightness=0.3)
    brightness_data = brightness(data)
    # Clip to [0, 255] before casting back, otherwise values above 255 overflow in uint8
    brightness_data = mx.nd.clip(brightness_data, a_min=0, a_max=255)
    brightness_data = mx.nd.Cast(brightness_data, dtype='uint8')
    plt.imshow(brightness_data.asnumpy())
    plt.savefig('{}_brightness.png'.format(image_name))

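The brightness rule stated above can be sketched in NumPy: draw alpha from [1 - brightness, 1 + brightness], scale every pixel, then clip before converting back to uint8. `brightness_jitter` is a hypothetical helper, and it returns alpha as well only so the result is easy to inspect.

```python
import random

import numpy as np


def brightness_jitter(img, brightness, rng=random):
    """Scale every pixel by alpha drawn from [1 - brightness, 1 + brightness]."""
    alpha = 1.0 + rng.uniform(-brightness, brightness)
    out = img.astype(np.float32) * alpha
    # clip before going back to uint8, otherwise values above 255 would overflow
    return np.clip(out, 0, 255).astype(np.uint8), alpha


img = np.full((2, 2, 3), 240, dtype=np.uint8)
out, alpha = brightness_jitter(img, 0.3, random.Random(1))
print(alpha, out[0, 0])
```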

3.5 Contrast (contrast)

  • mx.image.ContrastJitterAug(contrast=0.3)
    Do not set contrast too high, or the image becomes distorted.
import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000008.JPEG'
    image_name = image.split(".")[0]
    with open('../image/{}'.format(image), 'rb') as f:
        image_string = f.read()
    data = mx.image.imdecode(image_string, flag=1)
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # Cast to float32 before the pixel arithmetic
    cast = mx.image.CastAug()
    data = cast(data)
    contrast = mx.image.ContrastJitterAug(contrast=0.3)
    contrast_data = contrast(data)
    # Clip to [0, 255] before casting back to uint8
    contrast_data = mx.nd.clip(contrast_data, a_min=0, a_max=255)
    contrast_data = mx.nd.Cast(contrast_data, dtype='uint8')
    plt.imshow(contrast_data.asnumpy())
    plt.savefig('{}_contrast.png'.format(image_name))

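As far as I can tell from the MXNet source, ContrastJitterAug blends the image with the mean gray level of the whole image: out = alpha * img + (1 - alpha) * mean_gray. Alpha is drawn from [1 - contrast, 1 + contrast], and gray is computed with the RGB weights 0.299, 0.587, 0.114. Alpha above 1 pushes pixels away from the mean, which raises contrast; alpha below 1 pulls them toward it. The NumPy sketch below illustrates that assumption. `contrast_blend` and `contrast_jitter` are hypothetical names.

```python
import random

import numpy as np

LUMA = np.array([0.299, 0.587, 0.114], dtype=np.float32)  # RGB -> gray weights


def contrast_blend(img, alpha):
    """alpha * img + (1 - alpha) * mean gray level of the whole image (float32 result)."""
    img = img.astype(np.float32)
    mean_gray = (img * LUMA).sum(axis=2).mean()
    return alpha * img + (1.0 - alpha) * mean_gray


def contrast_jitter(img, contrast, rng=random):
    """Random alpha in [1 - contrast, 1 + contrast], then clip and cast back to uint8."""
    alpha = 1.0 + rng.uniform(-contrast, contrast)
    return np.clip(contrast_blend(img, alpha), 0, 255).astype(np.uint8)
```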

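3.6 Saturation (saturation)

  • saturation = mx.image.SaturationJitterAug(saturation=0.3)
    Saturation is the purity of a color: the higher it is, the more vivid the colors look; the lower it is, the duller they look.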

import matplotlib.pyplot as plt
import mxnet as mx

if __name__ == '__main__':
    image = 'ILSVRC2012_val_00000008.JPEG'
    image_name = image.split(".")[0]
    with open('../image/{}'.format(image), 'rb') as f:
        image_string = f.read()
    data = mx.image.imdecode(image_string, flag=1)
    plt.imshow(data.asnumpy())
    plt.savefig('{}_original.png'.format(image_name))

    # Cast to float32 before the pixel arithmetic
    cast = mx.image.CastAug()
    data = cast(data)
    saturation = mx.image.SaturationJitterAug(saturation=0.3)
    saturation_data = saturation(data)
    # Clip to [0, 255] before casting back to uint8
    saturation_data = mx.nd.clip(saturation_data, a_min=0, a_max=255)
    saturation_data = mx.nd.Cast(saturation_data, dtype='uint8')
    plt.imshow(saturation_data.asnumpy())
    plt.savefig('{}_saturation.png'.format(image_name))

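Saturation jitter can be read the same way, except that each pixel is blended with its own gray value rather than the image-wide mean. This follows my reading of SaturationJitterAug and is an assumption, not the library's documented formula. Alpha = 0 gives a grayscale image, alpha < 1 dulls the colors, and alpha > 1 makes them more vivid. `saturation_blend` is a hypothetical name, and the result must be clipped and cast to uint8 before display, as in the script above.

```python
import numpy as np

LUMA = np.array([0.299, 0.587, 0.114], dtype=np.float32)  # RGB -> gray weights


def saturation_blend(img, alpha):
    """alpha * img + (1 - alpha) * per-pixel gray value (float32 result)."""
    img = img.astype(np.float32)
    gray = (img * LUMA).sum(axis=2, keepdims=True)  # shape (H, W, 1), broadcast over RGB
    return alpha * img + (1.0 - alpha) * gray


red = np.array([[[255, 0, 0]]], dtype=np.uint8)
print(saturation_blend(red, 0.0))  # all three channels equal the gray value of pure red
```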
