mxnet实现自己的图像数据集增强方法

        深度学习做图像相关的内容时候,数据集增强是常用并且十分有效的手段,可以有效的对口过拟合以及提高模型的准确率,针对不同的问题有时候需要特定的方式对数据进行变换。Mxnet已经内置了一些常用的增强手段,例如randomcrop,mirror,颜色抖动等。

        但是,当需要具体的针对性的数据增强的时候,就需要自己写一个augmenter,对此 mxnet还是比较简单的和灵活的。

最方便的方式当然是在python层面继承Augmenter:


        直接上代码了,过程还是很清晰的,继承augmenter, 并实现方法即可

        代码中包括rotate 和noise增强, 基于opencv, 由于noise 增强大部分工作是python层面做的,速度比较慢,有待改进。

        调用方式是和调用mxnet的原有方法是一样的,比如代码文件是MyAugmentation.py。

        CAUTION: 这里提醒大家一下mxnet里面用图像增强代码里的一个小坑,(也不算坑吧,就是一个用法问题)。在做resize和crop的时候一般直接用就可以,但是做colorjitter或者类似的对图像的数据进行处理,需要先调用mx.image.CastAug()就是数据类型转换,不然会报错。

import MyAudmentation

taug_list_train=[ 
                mx.image.ForceResizeAug(size=(shape_,shape_)), 
                mx.image.RandomCropAug((shape_,shape_)), 
                mx.image.HorizontalFlipAug(0.5), 
                mx.image.CastAug(),
                ##################!!!!!!!!caution
                mx.image.ColorJitterAug(0.0, 0.1, 0.1),
                mx.image.HueJitterAug(0.5), 
                mx.image.LightingAug(0.1, eigval, eigvec),
                #####调用旋转增强旋转30度,0.5的概率
                MyAugmentation.RandomRotateAug(30,0.5)
                ]
train_iter = mx.image.ImageIter(batch_size=batch_size,
                                    data_shape=shape,
                                    label_width=1,
                                    aug_list=aug_list_train,
                                    shuffle=True,
                                    path_root='',
                                    path_imglist='/you/path/train.lst'
                                    )



MyAugmentation.py
import cv2
import mxnet as mx
from mxnet.image import  Augmenter
import random
import numpy as np
#######################实现对应的图像处理过程供调用
def rotate(src, angle, center=None, scale=1.0):
    image = src.asnumpy()
    (h, w) = image.shape[:2]
    # set the center point as the rotate center by default
    if center is None:
        center = (w / 2, h / 2)
    # opencv to 
    M = cv2.getRotationMatrix2D(center, angle, scale)
    rotated = cv2.warpAffine(image, M, (w, h))
    rotated = mx.nd.array(rotated,dtype=np.uint8)
    
    return rotated

def SaltAndPepper(src,percet):
    ###it is a very slow mothed, not recommended  to use it
    Salted=src
    image=int(percet*src.shape[0]*src.shape[1])
    for i in range(image):
        randX=random.randint(0,src.shape[0]-1)
        randY=random.randint(0,src.shape[1]-1)
        if random.randint(0,1)==0:
            Salted[randX,randY]=0.
        else:
            Salted[randX,randY]=255.
    return Salted



#######继承Augmenter,并实现两个方法即可

#####################
class RandomRotateAug(Augmenter):
    """Make randomrotate.
    Parameters
    ----------
    angel : float or int the max angel to rotate
    p : the possibility the img be rotated
    """
    def __init__(self, angel, possibility):
        super(RandomRotateAug, self).__init__(angel=angel)
        self.maxangel = angel
        self.p=possibility
    def __call__(self, src):
        """Augmenter body"""
        #return resize_short(src, self.size, self.interp)
        a = random.random()
        if a > self.p:
            return src
        else:
            angle=random.randint(-self.maxangel,self.maxangel)
            return rotate(src,angle)



class RandomNoiseAug(Augmenter):
    """Make randomrotate.
    Parameters
    ----------
    percet : how much should the img be noised
    p : the possibility the img be noised
    """
    def __init__(self, percet,possibility):
        super(RandomNoiseAug, self).__init__(percet=percet)
        self.percet = percet
        self.p=possibility
    def __call__(self, src):
        """Augmenter body"""
        #return resize_short(src, self.size, self.interp)
        a = random.random()
        if a > self.p:
            return src
        else:
            return SaltAndPepper(src,self.percet)

后续逐渐会有构造customed dataiterator,以及customed operator的介绍,如有错误请指正,并请谅解:)

你可能感兴趣的:(DL,mxnet)