# Data augmentation operations (rotation, flipping, cropping, color changes, Gaussian noise, etc.)

# -*- coding:utf-8 -*-

"""Data augmentation
   1. Flip
   2. Random crop
   3. Color jittering
   4. Shift
   5. Scale
   6. Contrast
   7. Noise
   8. Rotation/reflection
"""

from PIL import Image, ImageEnhance, ImageOps, ImageFile
import numpy as np
import random
import threading, os, time
import logging
import math
import shutil

logger = logging.getLogger(__name__)
ImageFile.LOAD_TRUNCATED_IMAGES = True


class DataAugmentation:
    """Collection of static, PIL-based image augmentation helpers.

    Implemented transforms: flip, shift, rotation, crop, color jitter and
    Gaussian noise. Each transform method takes a PIL image and returns a
    new PIL image. The random transforms draw from NumPy's global random
    state (``np.random``).
    """

    def __init__(self):
        pass

    @staticmethod
    def openImage(image):
        """Open an image file for reading.

        :param image: path (or file object) accepted by ``Image.open``
        :return: the opened PIL image
        """
        return Image.open(image, mode="r")

    @staticmethod
    def randomFlip(image, mode=Image.FLIP_LEFT_RIGHT):
        """Flip an image along one axis.

        Despite the name, nothing is random here: the caller selects the
        axis through ``mode``.

        :param image: PIL image
        :param mode: transpose constant such as ``Image.FLIP_LEFT_RIGHT``
            (default) or ``Image.FLIP_TOP_BOTTOM``
        :return: the flipped image
        """
        return image.transpose(mode)

    @staticmethod
    def randomShift(image):
        """Shift an image by a random offset; pixels wrap around the edges.

        The horizontal and vertical offsets are drawn independently from
        ``[0, ceil(0.2 * size))`` of the respective dimension.

        :param image: PIL image
        :return: the shifted image
        """
        # Local import: ImageChops is not part of the module-level PIL import.
        from PIL import ImageChops

        random_xoffset = np.random.randint(0, math.ceil(image.size[0] * 0.2))
        random_yoffset = np.random.randint(0, math.ceil(image.size[1] * 0.2))
        # Image.offset() is gone from Pillow; ImageChops.offset is its
        # replacement. Pass both offsets so the vertical one is actually used.
        return ImageChops.offset(image, int(random_xoffset), int(random_yoffset))

    @staticmethod
    def randomRotation(image, mode=Image.BICUBIC):
        """Rotate an image by a random whole angle between 1 and 359 degrees.

        :param image: PIL image
        :param mode: resampling filter passed to ``Image.rotate``
            (nearest, bilinear or bicubic; bicubic by default)
        :return: the rotated image
        """
        random_angle = np.random.randint(1, 360)
        return image.rotate(random_angle, mode)

    @staticmethod
    def randomCrop(image):
        """Crop a randomly placed window of 2/3 the width and 2/3 the height.

        :param image: PIL image
        :return: the cropped image
        """
        image_width, image_height = image.size
        crop_image_width = math.ceil(image_width * 2 / 3)
        crop_image_height = math.ceil(image_height * 2 / 3)
        # randint's upper bound is exclusive, so add 1: this makes the last
        # valid position reachable and avoids randint(0, 0) raising
        # ValueError when the crop is as large as the image (tiny images).
        x = int(np.random.randint(0, image_width - crop_image_width + 1))
        y = int(np.random.randint(0, image_height - crop_image_height + 1))
        random_region = (x, y, x + crop_image_width, y + crop_image_height)
        return image.crop(random_region)

    @staticmethod
    def randomColor(image):
        """Apply random color jitter: saturation, brightness, contrast, sharpness.

        Factors are drawn in steps of 0.1: saturation and sharpness from
        0.0 to 3.0, brightness and contrast from 1.0 to 2.0. The four
        enhancements are applied in that order, each on the previous result.

        :param image: PIL image
        :return: the jittered image
        """
        random_factor = np.random.randint(0, 31) / 10.
        color_image = ImageEnhance.Color(image).enhance(random_factor)  # saturation
        random_factor = np.random.randint(10, 21) / 10.
        brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor)
        random_factor = np.random.randint(10, 21) / 10.
        contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor)
        random_factor = np.random.randint(0, 31) / 10.
        return ImageEnhance.Sharpness(contrast_image).enhance(random_factor)

    @staticmethod
    def randomGaussian(image, mean=0.2, sigma=0.3):
        """Add independent Gaussian noise to the first three channels.

        Images whose pixel array is not 3-D with at least three channels
        (e.g. grayscale) are returned without noise. Extra channels such as
        alpha are left untouched.

        :param image: PIL image
        :param mean: mean of the noise added to every value
        :param sigma: standard deviation of the noise
        :return: a new image built from the noisy pixel array
        """
        # np.array (unlike np.asarray) returns a writable copy, so the source
        # image is never modified and no WRITEABLE flag has to be forced.
        pixels = np.array(image)
        if pixels.ndim == 3 and pixels.shape[2] >= 3:
            noisy = pixels[:, :, :3].astype(np.float64)
            noisy += np.random.normal(mean, sigma, size=noisy.shape)
            # Clip before the cast so values cannot wrap around in uint8;
            # the cast itself truncates the fractional part.
            pixels[:, :, :3] = np.clip(noisy, 0, 255).astype(np.uint8)
        return Image.fromarray(np.uint8(pixels))

    @staticmethod
    def saveImage(image, path):
        """Save ``image`` to ``path``; report a failure instead of raising.

        :param image: PIL image
        :param path: destination file path
        """
        try:
            image.save(path)
        except Exception as exc:
            # Best effort: one unsavable image must not abort a batch, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            print('not save img: ', path, exc)

# Mirror of the most recent get_files() result, kept so code that reads the
# module-level ``files`` name keeps working.
files = []


def _collect_files(dir_path, found):
    """Append every regular file below ``dir_path`` to ``found``, depth-first."""
    for entry in os.listdir(dir_path):
        child = os.path.join(dir_path, entry)
        if os.path.isfile(child):
            found.append(child)
        elif os.path.isdir(child):
            _collect_files(child, found)


def get_files(dir_path):
    """Recursively list the files under ``dir_path``.

    Each call returns a fresh list, so repeated calls neither accumulate
    duplicates nor alter a list returned earlier.

    :param dir_path: directory to scan
    :return: list of file paths (each prefixed with ``dir_path``) in
        ``os.listdir`` traversal order, or ``None`` if ``dir_path`` does
        not exist
    """
    if not os.path.exists(dir_path):
        return None
    found = []
    _collect_files(dir_path, found)
    # Slice assignment copies the entries, so the returned list and the
    # module-level mirror never alias each other.
    files[:] = found
    return found

if __name__ == '__main__':
    # Batch driver: mirror every file under imgs_dir into new_imgs_dir and,
    # for supported image types, write augmented variants next to the copy
    # as <name>_<N>.<ext>.
    times = 2  # number of variants generated per random transform
    imgs_dir = '/opt/sda/imgData20190322/train'
    new_imgs_dir = '/opt/sda/imgData20190322/train_data_augment'
    funcMap = {"flip": DataAugmentation.randomFlip,
               "rotation": DataAugmentation.randomRotation,
               "crop": DataAugmentation.randomCrop,
               "color": DataAugmentation.randomColor,
               "gaussian": DataAugmentation.randomGaussian
               }
    # A tuple, not a set: set iteration order of strings varies between
    # interpreter runs, which made the _N suffix -> transform mapping
    # irreproducible. "color" is available in funcMap but not enabled here.
    funcLists = ("flip", "rotation", "crop", "gaussian")
    flip_model = [Image.FLIP_LEFT_RIGHT, Image.FLIP_TOP_BOTTOM]
    image_postfixes = ('jpg', 'jpeg', 'png', 'bmp')

    imgs_list = get_files(imgs_dir)
    if imgs_list is None:
        # get_files returns None when the directory does not exist.
        print('imgs dir not exist: ', imgs_dir)
        imgs_list = []

    for index_img, img in enumerate(imgs_list):
        if index_img != 0 and index_img % 50 == 0:
            print('now is dealing %d image' % (index_img))

        # Rebuild the source sub-directory layout under new_imgs_dir; using
        # the relative path works regardless of how deep either root is.
        rel_dir = os.path.relpath(os.path.dirname(img), imgs_dir)
        new_img_dir = os.path.normpath(os.path.join(new_imgs_dir, rel_dir))
        os.makedirs(new_img_dir, exist_ok=True)

        file_name = os.path.basename(img)
        try:
            shutil.copy(img, os.path.join(new_img_dir, file_name))
        except OSError as exc:
            # Best effort, as before, but no longer silent.
            print('not copy img: ', img, exc)

        # splitext only looks at the final component's last dot, so dots in
        # directory names or inside the file name no longer break the split.
        img_name, dot_postfix = os.path.splitext(file_name)
        postfix = dot_postfix[1:]
        if postfix.lower() not in image_postfixes:
            continue

        try:
            image = DataAugmentation.openImage(img)
        except OSError as exc:
            print('not open img: ', img, exc)
            continue

        # The with-block closes the underlying file once all variants are
        # written, so handles do not pile up over a large dataset.
        with image:
            _index = 1
            for func in funcLists:
                if func == 'flip':
                    # Deterministic: one variant per flip direction.
                    new_images = (DataAugmentation.randomFlip(image, flip_mode)
                                  for flip_mode in flip_model)
                elif func == 'gaussian':
                    new_images = (DataAugmentation.randomGaussian(image),)
                else:
                    new_images = (funcMap[func](image) for _ in range(times))

                for new_image in new_images:
                    img_path = os.path.join(
                        new_img_dir, img_name + '_' + str(_index) + '.' + postfix)
                    DataAugmentation.saveImage(new_image, img_path)
                    _index += 1
                        

 

# You may also be interested in: (python, deep learning, caffe)