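In image-related deep learning work, data augmentation is a common and very effective technique: it helps counter overfitting and improve model accuracy, and some problems call for transformations specific to the task. MXNet already ships with several common augmentations, such as random cropping, mirroring, and color jitter.
However, when a problem needs a targeted augmentation of its own, you have to write a custom augmenter, and MXNet makes this fairly simple and flexible.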
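The most convenient way, of course, is to subclass Augmenter in Python.
The code follows, and the process is straightforward: subclass Augmenter and implement its two methods, __init__ and __call__.
The code contains a rotation augmenter and a noise augmenter, both based on OpenCV/NumPy arrays. Because most of the noise augmentation runs as a Python loop, it is slow and leaves room for improvement.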
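One way to remove the Python loop is to draw all noise positions at once with NumPy and write them with a single fancy-indexing assignment. The sketch below assumes the image is already a NumPy array; salt_and_pepper_fast is a hypothetical helper, not part of MXNet. Inside an augmenter you would convert with src.asnumpy() and wrap the result with mx.nd.array(..., dtype=...), the same round trip the rotate helper uses.

```python
import numpy as np


def salt_and_pepper_fast(img, percent, rng=None):
    """Vectorized salt-and-pepper noise (hypothetical helper).

    Sets roughly `percent` of the pixel positions (all channels at once)
    to 0 or 255. Positions are drawn with replacement, so a few may repeat.
    """
    rng = np.random.default_rng() if rng is None else rng
    out = img.copy()  # leave the caller's image untouched
    h, w = out.shape[:2]
    n = int(percent * h * w)
    ys = rng.integers(0, h, size=n)  # row indices
    xs = rng.integers(0, w, size=n)  # column indices
    values = rng.choice(np.array([0, 255], dtype=out.dtype), size=n)
    # one assignment replaces the per-pixel loop; broadcast over channels
    out[ys, xs] = values[:, None] if out.ndim == 3 else values
    return out
```

The design choice is simply to trade the per-pixel loop for two index arrays, which keeps the work inside NumPy.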
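The custom augmenters are used the same way as MXNet's built-in ones. For example, if the code is saved as MyAugmentation.py, it can be used as follows.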
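CAUTION: there is one small pitfall when using MXNet's image augmentation code (not really a bug, more a usage issue). Resize and crop can usually be applied directly, but before ColorJitterAug or any similar augmenter that operates on pixel values, you must first apply mx.image.CastAug() to convert the data type (from uint8 to float32); otherwise an error is raised.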
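The underlying reason is that decoded images are uint8, while color augmenters multiply and add floating-point values in place. The NumPy sketch below shows the same class of problem: storing float results into a uint8 array in place is rejected, and casting to float32 first (the role CastAug plays) makes it work. This only illustrates the dtype issue; MXNet's actual error message differs, and scale_inplace is a hypothetical stand-in for a brightness jitter.

```python
import numpy as np


def scale_inplace(img, alpha):
    """Multiply pixel values by alpha in place, as a brightness jitter would."""
    img *= alpha
    return img


pixel = np.array([[200, 100, 50]], dtype=np.uint8)  # one RGB pixel

# Without a cast: NumPy refuses to write float results into a uint8 array.
try:
    scale_inplace(pixel.copy(), 1.2)
except TypeError as err:
    print("uint8 input fails:", type(err).__name__)

# With a cast first (analogous to mx.image.CastAug()), the same call works.
print(scale_inplace(pixel.astype(np.float32), 1.2))
```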
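```python
import mxnet as mx
import numpy as np

import MyAugmentation

shape_ = 224                 # side length of the square training image (example value)
shape = (3, shape_, shape_)  # data_shape for ImageIter: (channels, height, width)
batch_size = 32              # example value
# PCA eigenvalues/eigenvectors for LightingAug: the ImageNet values used by
# mx.image.CreateAugmenter; replace them with your own dataset's statistics if needed.
eigval = np.array([55.46, 4.794, 1.148])
eigvec = np.array([[-0.5675, 0.7192, 0.4009],
                   [-0.5808, -0.0045, -0.8140],
                   [-0.5836, -0.6948, 0.4203]])

aug_list_train = [
    mx.image.ForceResizeAug(size=(shape_, shape_)),
    # note: cropping to the resize size keeps the whole image;
    # resize to a larger size first for an actual random crop
    mx.image.RandomCropAug((shape_, shape_)),
    mx.image.HorizontalFlipAug(0.5),
    mx.image.CastAug(),  # !!! CAUTION: cast to float32 before pixel-value augmenters
    mx.image.ColorJitterAug(0.0, 0.1, 0.1),
    mx.image.HueJitterAug(0.5),
    mx.image.LightingAug(0.1, eigval, eigvec),
    # custom augmenter: rotate by up to 30 degrees either way with probability 0.5
    MyAugmentation.RandomRotateAug(30, 0.5),
]

train_iter = mx.image.ImageIter(batch_size=batch_size,
                                data_shape=shape,
                                label_width=1,
                                aug_list=aug_list_train,
                                shuffle=True,
                                path_root='',
                                path_imglist='/your/path/train.lst')
```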
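ImageIter reads the list file passed as path_imglist. Each line of an MXNet .lst file is tab-separated: an integer index, the label (one value here, since label_width=1), and the image path, which is joined with path_root (with path_root='', relative paths are resolved against the working directory). The sketch below builds such a file for a hypothetical layout with one subfolder per class; scan_class_folders and write_lst are illustrative helpers, not MXNet functions. MXNet's tools/im2rec.py with the --list option can also generate these files.

```python
import os


def scan_class_folders(root, exts=(".jpg", ".jpeg", ".png")):
    """Return (relative_path, label) pairs, one label per sorted subfolder."""
    classes = sorted(d for d in os.listdir(root)
                     if os.path.isdir(os.path.join(root, d)))
    samples = []
    for label, cls in enumerate(classes):
        for name in sorted(os.listdir(os.path.join(root, cls))):
            if name.lower().endswith(exts):  # skip non-image files
                samples.append((os.path.join(cls, name), label))
    return samples


def write_lst(samples, lst_path):
    """Write index<TAB>label<TAB>path lines in the .lst layout."""
    with open(lst_path, "w") as f:
        for idx, (rel_path, label) in enumerate(samples):
            f.write(f"{idx}\t{float(label)}\t{rel_path}\n")
```

With relative paths like these, path_root should point at the dataset root.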
```python
# MyAugmentation.py
import random

import cv2
import mxnet as mx
import numpy as np
from mxnet.image import Augmenter


# ---------- image-processing helpers used by the augmenters ----------
def rotate(src, angle, center=None, scale=1.0):
    """Rotate an HxWxC NDArray by `angle` degrees (counterclockwise) with OpenCV."""
    image = src.asnumpy()
    (h, w) = image.shape[:2]
    # use the image center as the rotation center by default
    if center is None:
        center = (w / 2, h / 2)
    # build the 2x3 affine rotation matrix and apply it;
    # uncovered corners are filled with 0 (black)
    M = cv2.getRotationMatrix2D(center, angle, scale)
    rotated = cv2.warpAffine(image, M, (w, h))
    # keep the input dtype (e.g. float32 after CastAug) instead of forcing uint8
    return mx.nd.array(rotated, dtype=image.dtype)


def salt_and_pepper(src, percent):
    """Set about `percent` of the pixels to black (0) or white (255).

    Very slow (one Python iteration and one NDArray write per pixel);
    not recommended for large images.
    """
    salted = src.copy()  # do not modify the input in place
    num_pixels = int(percent * src.shape[0] * src.shape[1])
    for _ in range(num_pixels):
        rand_x = random.randint(0, src.shape[0] - 1)
        rand_y = random.randint(0, src.shape[1] - 1)
        if random.randint(0, 1) == 0:
            salted[rand_x, rand_y] = 0.
        else:
            salted[rand_x, rand_y] = 255.
    return salted


# ---------- subclass Augmenter and implement __init__ and __call__ ----------
class RandomRotateAug(Augmenter):
    """Randomly rotate the image.

    Parameters
    ----------
    angle : float or int
        Maximum rotation angle in degrees; the actual angle is drawn
        uniformly from [-angle, angle].
    possibility : float
        Probability that the image is rotated.
    """
    def __init__(self, angle, possibility):
        super(RandomRotateAug, self).__init__(angle=angle, possibility=possibility)
        self.max_angle = angle
        self.p = possibility

    def __call__(self, src):
        """Augmenter body"""
        if random.random() > self.p:
            return src
        # uniform() also accepts a float maximum angle, unlike randint()
        angle = random.uniform(-self.max_angle, self.max_angle)
        return rotate(src, angle)


class RandomNoiseAug(Augmenter):
    """Randomly add salt-and-pepper noise.

    Parameters
    ----------
    percent : float
        Fraction of pixels to replace with noise.
    possibility : float
        Probability that noise is added to the image.
    """
    def __init__(self, percent, possibility):
        super(RandomNoiseAug, self).__init__(percent=percent, possibility=possibility)
        self.percent = percent
        self.p = possibility

    def __call__(self, src):
        """Augmenter body"""
        if random.random() > self.p:
            return src
        return salt_and_pepper(src, self.percent)
```
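For reference, per the OpenCV documentation, cv2.getRotationMatrix2D(center, angle, scale) returns the 2x3 matrix [[a, b, (1 - a)*cx - b*cy], [-b, a, b*cx + (1 - a)*cy]], where a = scale*cos(angle), b = scale*sin(angle) and (cx, cy) is the center; a positive angle rotates counterclockwise. The NumPy sketch below rebuilds that matrix to make the geometry of the rotate helper explicit. rotation_matrix_2d is a hypothetical helper, not an OpenCV function.

```python
import numpy as np


def rotation_matrix_2d(center, angle_deg, scale=1.0):
    """2x3 affine matrix with the same layout as cv2.getRotationMatrix2D."""
    cx, cy = center
    theta = np.deg2rad(angle_deg)
    a = scale * np.cos(theta)
    b = scale * np.sin(theta)
    return np.array([
        [a, b, (1 - a) * cx - b * cy],
        [-b, a, b * cx + (1 - a) * cy],
    ])


# The rotation center is a fixed point of the transform.
M = rotation_matrix_2d((10.0, 20.0), 30)
print(M @ np.array([10.0, 20.0, 1.0]))
```

Because the output size passed to cv2.warpAffine stays (w, h), rotated corners fall outside the frame and the uncovered areas are filled with black.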
Later posts will cover building a custom data iterator and a custom operator. If you spot any mistakes, please point them out, and thanks for your understanding :)