语义分割数据集的扩充

语义分割数据集

语义分割数据集的mask图片并不是8bit的灰度图,而是png图片特有的通过调色盘去设置颜色的单通道8bit彩色图片,通过调色盘对应的值来显示对应的颜色,并且mask对应的值也对应着类别的标签。
语义分割数据集的扩充_第1张图片
当通过cv2去对图片进行扩充的时候,cv2的图片读取方式的原因,会将图片读取成3通道的彩图,并且保存时会保存成24bit的图片(即3通道8bit),所以语义分割数据集的mask最好不要用cv2进行读取处理,本文方法使用的是PIL库的Image方法对mask图片进行读取,并且使用PIL库的中的方法对图片进行操作扩充。参考文章:https://blog.csdn.net/qq_20852429/article/details/79137777

# -*- coding:utf-8 -*-
"""数据增强
   1. 翻转变换 flip
   2. 随机修剪 random crop
   3. 色彩抖动 color jittering
   4. 平移变换 shift
   5. 尺度变换 scale
   6. 对比度变换 contrast
   7. 噪声扰动 noise
   8. 旋转变换/反射变换 Rotation/reflection
"""

from PIL import Image, ImageEnhance, ImageOps, ImageFile
import numpy as np
import random
import threading, os, time
import logging

logger = logging.getLogger(__name__)
ImageFile.LOAD_TRUNCATED_IMAGES = True


class DataAugmentation:
    """
    包含数据增强的7种方式
    """

    def __init__(self):
        pass

    @staticmethod
    def openImage(image):
        return Image.open(image, mode="r")

    @staticmethod
    def randomRotation(image, label, mode=Image.BICUBIC):
        """
         对图像进行随机任意角度(0~360度)旋转
        :param mode 邻近插值,双线性插值,双三次B样条插值(default)
        :param image PIL的图像image
        :return: 旋转转之后的图像
        """
        random_angle = np.random.randint(1, 360)
        random_trans = np.random.randint(0, 6)
        # label = label.rotate(random_angle, Image.NEAREST).convert(mode="P", )
        # return image.rotate(random_angle, mode), label.rotate(random_angle, Image.NEAREST)
        return image.transpose(random_trans), label.transpose(random_trans)
        # return image.rotate(random_angle, mode), label

    # 暂时未使用这个函数
    @staticmethod
    def randomCrop(image, label):
        """
        对图像随意剪切,考虑到图像大小范围(68,68),使用一个一个大于(36*36)的窗口进行截图
        :param image: PIL的图像image
        :return: 剪切之后的图像
        """
        image_width = image.size[0]
        image_height = image.size[1]
        crop_win_size = np.random.randint(300, 500)
        random_region = (
            (image_width - crop_win_size) >> 1, (image_height - crop_win_size) >> 1, (image_width + crop_win_size) >> 1,
            (image_height + crop_win_size) >> 1)
        return image.crop(random_region), label.crop(random_region)

    @staticmethod
    def randomColor(image, label):
        """
        对图像进行颜色抖动
        :param image: PIL的图像image
        :return: 有颜色色差的图像image
        """
        random_factor = np.random.randint(0, 31) / 10.  # 随机因子
        color_image = ImageEnhance.Color(image).enhance(random_factor)  # 调整图像的饱和度
        random_factor = np.random.randint(10, 21) / 10.  # 随机因子
        brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor)  # 调整图像的亮度
        random_factor = np.random.randint(10, 21) / 10.  # 随机因1子
        contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor)  # 调整图像对比度
        random_factor = np.random.randint(0, 31) / 10.  # 随机因子
        return ImageEnhance.Sharpness(contrast_image).enhance(random_factor), label  # 调整图像锐度

    @staticmethod
    def saveImage(image, path):
        image.save(path)


def makeDir(path):
    if not os.path.exists(path):
        if not os.path.isfile(path):
            # os.mkdir(path)
            os.makedirs(path)
        return 0
    else:
        return 1


def imageOps(func_name, image, label, img_des_path, label_des_path, img_file_name, label_file_name, times=5):
    funcMap = {"randomRotation": DataAugmentation.randomRotation,
               "randomCrop": DataAugmentation.randomCrop,
               "randomColor": DataAugmentation.randomColor,
               # "randomGaussian": DataAugmentation.randomGaussian
               }
    if funcMap.get(func_name) is None:
        logger.error("%s is not exist", func_name)
        return -1

    for _i in range(0, times, 1):
        new_image, new_label = funcMap[func_name](image, label)
        DataAugmentation.saveImage(new_image, os.path.join(img_des_path, func_name + str(_i) + img_file_name))
        DataAugmentation.saveImage(new_label, os.path.join(label_des_path, func_name + str(_i) + label_file_name))


opsList = {"randomRotation", "randomCrop", "randomColor"}
# opsList = {"randomRotation"}


def threadOPS(img_path, new_img_path, label_path, new_label_path):
    """
    多线程处理事务
    :param src_path: 资源文件
    :param des_path: 目的地文件
    :return:
    """
    # img path
    if os.path.isdir(img_path):
        img_names = os.listdir(img_path)
    else:
        img_names = [img_path]

    # label path
    if os.path.isdir(label_path):
        label_names = os.listdir(label_path)
    else:
        label_names = [label_path]

    img_num = 0
    label_num = 0

    # img num
    for img_name in img_names:
        tmp_img_name = os.path.join(img_path, img_name)
        if os.path.isdir(tmp_img_name):
            print('contain file folder')
            exit()
        else:
            img_num = img_num + 1;
    # label num
    for label_name in label_names:
        tmp_label_name = os.path.join(label_path, label_name)
        if os.path.isdir(tmp_label_name):
            print('contain file folder')
            exit()
        else:
            label_num = label_num + 1

    if img_num != label_num:
        print('the num of img and label is not equl')
        exit()
    else:
        num = img_num

    for i in range(num):
        img_name = img_names[i]
        label_name = label_names[i]

        tmp_img_name = os.path.join(img_path, img_name)
        tmp_label_name = os.path.join(label_path, label_name)

        # 读取文件并进行操作
        image = DataAugmentation.openImage(tmp_img_name)
        label = DataAugmentation.openImage(tmp_label_name)

        threadImage = [0] * 3
        _index = 0
        for ops_name in opsList:
            threadImage[_index] = threading.Thread(target=imageOps,
                                                   args=(ops_name, image, label, new_img_path, new_label_path, img_name,
                                                         label_name))
            threadImage[_index].start()
            _index += 1
            time.sleep(0.2)


if __name__ == '__main__':
    threadOPS(r"D:\datasets\EddyDataset\JPEGImages",    # 原image
              r"D:\datasets\Augmentor\JPEGImages",		# 扩充的image
              r"D:\datasets\EddyDataset\SegmentationClass",   # 原mask
              r"D:\datasets\Augmentor\SegmentationClass")		# 扩充的mask

原文章使用多线程以及Image方法同时对image和mask同时扩充,但是高斯方法有些问题,这里就把高斯方法去掉,还有一个地方不同的就是在对图像旋转的时候,扩充的mask也会无法生成,仅仅只是生成黑色的图片,经过实验发现transpose方法进行反转之后的mask图像可以保存,所以将原有的rotate函数改成了transpose函数来对数据进行扩充。

由于生成的图片名称为扩充方法加数值的形式,本文额外增加了重命名代码:

import os


def reName(img_path, label_path, img_save_path, label_save_path, img_name, img_type, label_type, name):
    if not os.path.exists(img_save_path):
        os.makedirs(img_save_path)
    if not os.path.exists(label_save_path):
        os.makedirs(label_save_path)
    os.renames(os.path.join(img_path, img_name.split('.')[0] + img_type), os.path.join(img_save_path, name + img_type))
    os.renames(os.path.join(label_path, img_name.split('.')[0] + label_type), os.path.join(label_save_path, name + label_type))
    print(f"save file:{name + img_type}")

if __name__ == "__main__":
    img_path = r"D:\datasets\Augment\JPEGImages"
    png_path = r"D:\datasets\Augment\SegmentationClass"
    img_save_path = r"D:\datasets\Augment\rename\JPEGImages"
    png_save_path = r"D:\datasets\Augment\rename\SegmentationClass"   # 重命名的image,mask保存的路径和源路径相同即可
    img_list = os.listdir(img_path)
    num = 0
    for name in img_list:
        name_ = str(num).zfill(6)
        reName(img_path, png_path, img_save_path, png_save_path, name, ".jpg", ".png", name_)
        num += 1

你可能感兴趣的:(数据处理,计算机视觉,python,深度学习)