【目标检测实战学习】数据增强的几种方法:cutout,mixup,mosaic,rotate,HSV,随机抖动实战

最近在学习数据增强方面的东西,简单做个记录

首先需要强调的是,数据增强是目标检测流程中的一个过程,通常是在对数据集完成打标签之后,在划分数据集之前,为了增大数据集的数量,获取更多的特征,采用的一种方式。所以,在实战的过程中,不仅仅要对图像进行操作,还要对已经打好的标签(VOC数据集的xml文件)进行同样的对应操作

随机抖动,mosaic,mixup三种方法参考的是GitHub上大佬的代码,链接如下:

bubbliiiing/object-detection-augmentation: 这里面存放了一些目标检测算法的数据增强方法。如mosaic、mixup。 (github.com)

rotate,HSV,的参考

DataAugmentation_ForObjectDetect/rotated.py at master · DLLXW/DataAugmentation_ForObjectDetect · GitHub

cutout方法是看了两篇博主的博客,然后缝合出来的,链接如下:

(24条消息) 数据增强实测之cutout_一个菜鸟的奋斗的博客-CSDN博客_cutout数据增强

(24条消息) Cutout一种新的正则化方法_点PY的博客-CSDN博客_cutout方法

下面开始正题

首先需要准备好需要数据加强的原图(我使用的是jpeg格式)以及与原图一一对应的标签文件,命名必须一样

如何使用labelimg制作VOC数据集可以看我之前的贴子

【目标检测实战学习】从零开始制作并训练自己的VOC数据集,并使用Retinanet进行目标检测_Bill~QAQ~的博客-CSDN博客

1.旋转(rotate)

就是将图片进行一定角度的旋转,当旋转角度为180度或者360度时,图像的大小不会发生变化,否则图像就会呈现内切矩形的大小:

原图

【目标检测实战学习】数据增强的几种方法:cutout,mixup,mosaic,rotate,HSV,随机抖动实战_第1张图片

旋转200度后

【目标检测实战学习】数据增强的几种方法:cutout,mixup,mosaic,rotate,HSV,随机抖动实战_第2张图片

 代码如下:

import cv2
import math
import numpy as np
import os
import glob
import json
import shutil
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import ElementTree, Element

def getRotatedImg(Pi_angle,img_path,img_write_path):
    img = cv2.imread(img_path)
    rows, cols = img.shape[:2]
    a, b = cols / 2, rows / 2
    M = cv2.getRotationMatrix2D((a, b), angle, 1)
    rotated_img = cv2.warpAffine(img, M, (cols, rows))  # 旋转后的图像保持大小不变
    cv2.imwrite(img_write_path,rotated_img)
    return a,b

def getRotatedAnno(Pi_angle,a,b,anno_path,anno_write_path):
    tree = ET.parse(anno_path)
    root = tree.getroot()
    objects = root.findall("object")
    for obj in objects:
        bbox = obj.find('bndbox')
        x1 = float(bbox.find('xmin').text) - 1
        y1 = float(bbox.find('ymin').text) - 1
        x2 = float(bbox.find('xmax').text) - 1
        y2 = float(bbox.find('ymax').text) - 1

        x3=x1
        y3=y2
        x4=x2
        y4=y1

        X1 = (x1 - a) * math.cos(Pi_angle) - (y1 - b) * math.sin(Pi_angle) + a
        Y1 = (x1 - a) * math.sin(Pi_angle) + (y1 - b) * math.cos(Pi_angle) + b

        X2 = (x2 - a) * math.cos(Pi_angle) - (y2 - b) * math.sin(Pi_angle) + a
        Y2 = (x2 - a) * math.sin(Pi_angle) + (y2 - b) * math.cos(Pi_angle) + b

        X3 = (x3 - a) * math.cos(Pi_angle) - (y3 - b) * math.sin(Pi_angle) + a
        Y3 = (x3 - a) * math.sin(Pi_angle) + (y3 - b) * math.cos(Pi_angle) + b

        X4 = (x4 - a) * math.cos(Pi_angle) - (y4 - b) * math.sin(Pi_angle) + a
        Y4 = (x4 - a) * math.sin(Pi_angle) + (y4 - b) * math.cos(Pi_angle) + b

        X_MIN=min(X1,X2,X3,X4)
        X_MAX = max(X1, X2, X3, X4)
        Y_MIN = min(Y1, Y2, Y3, Y4)
        Y_MAX = max(Y1, Y2, Y3, Y4)

        bbox.find('xmin').text=str(int(X_MIN))
        bbox.find('ymin').text=str(int(Y_MIN))
        bbox.find('xmax').text=str(int(X_MAX))
        bbox.find('ymax').text=str(int(Y_MAX))

    tree.write(anno_write_path)  # 保存修改后的XML文件

def rotate(angle,img_dir,anno_dir,img_write_dir,anno_write_dir):
    if not os.path.exists(img_write_dir):
        os.makedirs(img_write_dir)

    if not os.path.exists(anno_write_dir):
        os.makedirs(anno_write_dir)

    Pi_angle = -angle * math.pi / 180.0  # 弧度制,后面旋转坐标需要用到,注意负号!!!
    img_names=os.listdir(img_dir)
    for img_name in img_names:
        img_path=os.path.join(img_dir,img_name)
        img_write_path=os.path.join(img_write_dir,img_name[:-4]+'.jpg')
        #
        anno_path=os.path.join(anno_dir,img_name[:-4]+'.xml')
        anno_write_path = os.path.join(anno_write_dir, img_name[:-4]+'.xml')
        #
        a,b=getRotatedImg(Pi_angle,img_path,img_write_path)
        getRotatedAnno(Pi_angle,a,b,anno_path,anno_write_path)

angle=200
img_dir='VOCdevkit_Origin/VOC2007/JPEGImages'
anno_dir='VOCdevkit_Origin/VOC2007/Annotations'
img_write_dir='H:\object-detection-augmentation\VOCdevkit\VOC2007\JPEGImages'
anno_write_dir='H:\object-detection-augmentation\VOCdevkit\VOC2007\Annotations'

rotate(angle,img_dir,anno_dir,img_write_dir,anno_write_dir)

2.Cutout

cutout就是在图像中随机选取一个或多个区域,将其裁剪掉(像素设为0或者其他)

原图:

【目标检测实战学习】数据增强的几种方法:cutout,mixup,mosaic,rotate,HSV,随机抖动实战_第3张图片

 cutout之后:

【目标检测实战学习】数据增强的几种方法:cutout,mixup,mosaic,rotate,HSV,随机抖动实战_第4张图片

需要注意的是,因为cutout方法不会改变xml标签中bbox的位置信息,所以该方法不需要对标签文件xml进行修改

操作代码如下:

import cv2
from torchvision import transforms
import os.path
import glob
import torch
import numpy as np

class Cutout(object):
    def __init__(self, n_holes=30, length=25):
        self.n_holes = n_holes
        self.length = length
    def __call__(self, img):
        h = img.size(1)
        w = img.size(2)
        mask = np.ones((h, w), np.float32)
        for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)
            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)
            mask[y1: y2, x1: x2] = 0.
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask
        return img

def cutout_img(inputfile,outputfile):
    src=cv2.imread(inputfile,cv2.IMREAD_UNCHANGED)
    src=transforms.ToTensor()(src)
    #每张图片中cutout的大小
    cut=Cutout(length=25,n_holes=50)
    src=cut(src)
    #写入
    src=src.mul(255).byte()
    src=src.numpy().transpose((1,2,0))
    cv2.imwrite(os.path.join(outputfile,os.path.basename(inputfile)),src)

#需要进行预处理的图片路径
for inputfile in glob.glob(r'H:\object-detection-augmentation\VOCdevkit_Origin\VOC2007\JPEGImages\*jpeg'):
#保存的路径,这里设置为原文件的路径,直接覆盖
    cutout_img(inputfile,r'H:\object-detection-augmentation\VOCdevkit\VOC2007\JPEGImages')


讲几个部分的参数

 def __init__(self, n_holes=30, length=25):

n_holes为需要生成多少个切割方块

length为每一个切割方块的大小(length*length)

mask[y1: y2, x1: x2] = 0.

0.就是代表所有的切割方块都是黑色,调成其他值就会有其他的颜色(155就会呈现出类似马赛克的颜色)

在最后的路径中

#需要进行预处理的图片路径
for inputfile in glob.glob(r'H:\object-detection-augmentation\VOCdevkit_Origin\VOC2007\JPEGImages\*jpg'):
#保存的路径,这里设置为原文件的路径,直接覆盖
    cutout_img(inputfile,r'H:\object-detection-augmentation\VOCdevkit\VOC2007\JPEGImages')

glob.glob代表一个一个遍历这个文件夹下的图片,最后这个路径要加上\*jpg,代表读取JPEGImages下所有结尾是jpeg的文件

3.mixup

mixup就是图像融合,在数据集中随机选取两张图片,进行图像融合,同时带着他们原始图像的标签位置变动

【目标检测实战学习】数据增强的几种方法:cutout,mixup,mosaic,rotate,HSV,随机抖动实战_第5张图片

代码如下:

import os
from random import sample

import numpy as np
from PIL import Image, ImageDraw

from utils.random_data import get_random_data, get_random_data_with_MixUp
from utils.utils import convert_annotation, get_classes

#-----------------------------------------------------------------------------------#
#   Origin_VOCdevkit_path   原始标签所在的路径
#   Out_VOCdevkit_path      输出标签所在的路径
#                                   处理后的标签为灰度图,如果设置的值太小会看不见具体情况。
#-----------------------------------------------------------------------------------#
Origin_VOCdevkit_path   = "VOCdevkit_Origin"
Out_VOCdevkit_path      = "VOCdevkit"
#-----------------------------------------------------------------------------------#
#   Out_Num                 利用mixup生成多少组图片
#   input_shape             生成的图片大小
#-----------------------------------------------------------------------------------#
Out_Num                 = 50
input_shape             = [640, 640]

#-----------------------------------------------------------------------------------#
#   下面定义了xml里面的组成模块,无需改动。
#-----------------------------------------------------------------------------------#
headstr = """\

    VOC
    %s
    
        My Database
        COCO
        flickr
        NULL
    
    
        NULL
        company
    
    
        %d
        %d
        %d
    
    0
"""

objstr = """\
    
        %s
        Unspecified
        0
        0
        
            %d
            %d
            %d
            %d
        
    
"""
    
tailstr = '''\

'''
if __name__ == "__main__":
    Origin_JPEGImages_path  = os.path.join(Origin_VOCdevkit_path, "VOC2007/JPEGImages")
    Origin_Annotations_path = os.path.join(Origin_VOCdevkit_path, "VOC2007/Annotations")
    
    Out_JPEGImages_path  = os.path.join(Out_VOCdevkit_path, "VOC2007/JPEGImages")
    Out_Annotations_path = os.path.join(Out_VOCdevkit_path, "VOC2007/Annotations")
    
    if not os.path.exists(Out_JPEGImages_path):
        os.makedirs(Out_JPEGImages_path)
    if not os.path.exists(Out_Annotations_path):
        os.makedirs(Out_Annotations_path)
    #---------------------------#
    #   遍历标签并赋值
    #---------------------------#
    xml_names = os.listdir(Origin_Annotations_path)

    def write_xml(anno_path, jpg_pth, head, input_shape, boxes, unique_labels, tail):
        f = open(anno_path, "w")
        f.write(head%(jpg_pth, input_shape[0], input_shape[1], 3))
        for i, box in enumerate(boxes):
            f.write(objstr%(str(unique_labels[int(box[4])]), box[0], box[1], box[2], box[3]))
        f.write(tail)
    #########以上部分无需改动###########################################################################################
    #------------------------------#
    #   循环生成xml和jpg
    #------------------------------#
    for index in range(Out_Num):
        #------------------------------#
        #   获取两个图像与标签
        #------------------------------#
        sample_xmls = sample(xml_names, 2)
        unique_labels = get_classes(sample_xmls, Origin_Annotations_path)

        jpg_name_1  = os.path.join(Origin_JPEGImages_path, os.path.splitext(sample_xmls[0])[0] + '.jpg')
        jpg_name_2  = os.path.join(Origin_JPEGImages_path, os.path.splitext(sample_xmls[1])[0] + '.jpg')
        xml_name_1  = os.path.join(Origin_Annotations_path, sample_xmls[0])
        xml_name_2  = os.path.join(Origin_Annotations_path, sample_xmls[1])
            
        line_1 = convert_annotation(jpg_name_1, xml_name_1, unique_labels)
        line_2 = convert_annotation(jpg_name_2, xml_name_2, unique_labels)
        
        #------------------------------#
        #   各自数据增强
        #------------------------------#
        image_1, box_1  = get_random_data(line_1, input_shape) 
        image_2, box_2  = get_random_data(line_2, input_shape) 
        
        #------------------------------#
        #   合并mixup
        #------------------------------#
        image_data, box_data = get_random_data_with_MixUp(image_1, box_1, image_2, box_2)
        
        img = Image.fromarray(image_data.astype(np.uint8))
        img.save(os.path.join(Out_JPEGImages_path, str(index) + '.jpg'))
        write_xml(os.path.join(Out_Annotations_path, str(index) + '.xml'), os.path.join(Out_JPEGImages_path, str(index) + '.jpg'), \
                    headstr, input_shape, box_data, unique_labels, tailstr)

 可以注释修改的部分都已经注释在里面了

4.mosaic

mosaic有点类似于拼图,在数据集中随机选取四张图片,进行随机抖动,然后将四张拼成一张图

 【目标检测实战学习】数据增强的几种方法:cutout,mixup,mosaic,rotate,HSV,随机抖动实战_第6张图片

代码如下:

import os
from random import sample

import numpy as np
from PIL import Image, ImageDraw

from utils.random_data import get_random_data, get_random_data_with_Mosaic
from utils.utils import convert_annotation, get_classes

#-----------------------------------------------------------------------------------#
#   Origin_VOCdevkit_path   原始标签所在的路径
#   Out_VOCdevkit_path      输出标签所在的路径
#                                   处理后的标签为灰度图,如果设置的值太小会看不见具体情况。
#-----------------------------------------------------------------------------------#
Origin_VOCdevkit_path   = "VOCdevkit_Origin"
Out_VOCdevkit_path      = "VOCdevkit"
#-----------------------------------------------------------------------------------#
#   Out_Num                 利用mixup生成多少组图片
#   input_shape             生成的图片大小
#-----------------------------------------------------------------------------------#
Out_Num                 = 50
input_shape             = [640, 640]

#-----------------------------------------------------------------------------------#
#   下面定义了xml里面的组成模块,无需改动。
#-----------------------------------------------------------------------------------#
headstr = """\

    VOC
    %s
    
        My Database
        COCO
        flickr
        NULL
    
    
        NULL
        company
    
    
        %d
        %d
        %d
    
    0
"""

objstr = """\
    
        %s
        Unspecified
        0
        0
        
            %d
            %d
            %d
            %d
        
    
"""
    
tailstr = '''\

'''
if __name__ == "__main__":
    Origin_JPEGImages_path  = os.path.join(Origin_VOCdevkit_path, "VOC2007/JPEGImages")
    Origin_Annotations_path = os.path.join(Origin_VOCdevkit_path, "VOC2007/Annotations")
    
    Out_JPEGImages_path  = os.path.join(Out_VOCdevkit_path, "VOC2007/JPEGImages")
    Out_Annotations_path = os.path.join(Out_VOCdevkit_path, "VOC2007/Annotations")
    
    if not os.path.exists(Out_JPEGImages_path):
        os.makedirs(Out_JPEGImages_path)
    if not os.path.exists(Out_Annotations_path):
        os.makedirs(Out_Annotations_path)
    #---------------------------#
    #   遍历标签并赋值
    #---------------------------#
    xml_names = os.listdir(Origin_Annotations_path)

    def write_xml(anno_path, jpg_pth, head, input_shape, boxes, unique_labels, tail):
        f = open(anno_path, "w")
        f.write(head%(jpg_pth, input_shape[0], input_shape[1], 3))
        for i, box in enumerate(boxes):
            f.write(objstr%(str(unique_labels[int(box[4])]), box[0], box[1], box[2], box[3]))
        f.write(tail)
    #########以上部分无需改动###########################################################################################
    #------------------------------#
    #   循环生成xml和jpg
    #------------------------------#
    for index in range(Out_Num):
        #------------------------------#
        #   获取4个图像与标签
        #------------------------------#
        sample_xmls     = sample(xml_names, 4)
        unique_labels   = get_classes(sample_xmls, Origin_Annotations_path)

        annotation_line = []
        for xml in sample_xmls:
            line = convert_annotation(os.path.join(Origin_JPEGImages_path, os.path.splitext(xml)[0] + '.jpeg'), os.path.join(Origin_Annotations_path, xml), unique_labels)
            annotation_line.append(line)
        #------------------------------#
        #   合并mosaic
        #------------------------------#
        image_data, box_data = get_random_data_with_Mosaic(annotation_line, input_shape)
        img = Image.fromarray(image_data.astype(np.uint8))
        img.save(os.path.join(Out_JPEGImages_path, str(index) + '.jpeg'))
        write_xml(os.path.join(Out_Annotations_path, str(index) + '.xml'), os.path.join(Out_JPEGImages_path, str(index) + '.jpeg'), \
                    headstr, input_shape, box_data, unique_labels, tailstr)

5.亮度,对比度调整

修改图像的一系列HSV参数,改变图像的亮度,色相,饱和度等数据

原图:

【目标检测实战学习】数据增强的几种方法:cutout,mixup,mosaic,rotate,HSV,随机抖动实战_第7张图片

进行参数调整后:

【目标检测实战学习】数据增强的几种方法:cutout,mixup,mosaic,rotate,HSV,随机抖动实战_第8张图片

import cv2
import math
import numpy as np
import os
import glob
import json
import shutil
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import ElementTree, Element

def getColorImg(alpha,beta,img_path,img_write_path):
    img = cv2.imread(img_path)
    colored_img = np.uint8(np.clip((alpha * img + beta), 0, 255))
    cv2.imwrite(img_write_path,colored_img)

def getColorAnno(anno_path,anno_write_path):
    tree = ET.parse(anno_path)
    tree.write(anno_write_path)  # 保存修改后的XML文件

def color(alpha,beta,img_dir,anno_dir,img_write_dir,anno_write_dir):
    if not os.path.exists(img_write_dir):
        os.makedirs(img_write_dir)

    if not os.path.exists(anno_write_dir):
        os.makedirs(anno_write_dir)
    img_names=os.listdir(img_dir)
    for img_name in img_names:
        img_path=os.path.join(img_dir,img_name)
        img_write_path=os.path.join(img_write_dir,img_name[:-4]+'color'+str(int(alpha*10))+'.jpg')
        #
        anno_path=os.path.join(anno_dir,img_name[:-4]+'.xml')
        anno_write_path = os.path.join(anno_write_dir, img_name[:-4]+'color'+str(int(alpha*10))+'.xml')
        #
        getColorImg(alpha,beta,img_path,img_write_path)
        getColorAnno(anno_path,anno_write_path)

alphas=[0.3,0.5,1.2,1.6]
beta=10
img_dir='VOCdevkit_Origin/VOC2007/JPEGImages'
anno_dir='VOCdevkit_Origin/VOC2007/Annotations'
img_write_dir='H:\object-detection-augmentation\VOCdevkit\VOC2007\JPEGImages'
anno_write_dir='H:\object-detection-augmentation\VOCdevkit\VOC2007\Annotations'
for alpha in alphas:
    color(alpha,beta,img_dir,anno_dir,img_write_dir,anno_write_dir)

 6.随机抖动

随机抖动是在对图像进行HSV变换之后,加上对图像的随机resize,裁剪,平移,旋转

原图:

【目标检测实战学习】数据增强的几种方法:cutout,mixup,mosaic,rotate,HSV,随机抖动实战_第9张图片

随机抖动之后:

【目标检测实战学习】数据增强的几种方法:cutout,mixup,mosaic,rotate,HSV,随机抖动实战_第10张图片

import os
from random import sample

import numpy as np
from PIL import Image, ImageDraw

from utils.random_data import get_random_data, get_random_data_with_MixUp
from utils.utils import convert_annotation, get_classes

#-----------------------------------------------------------------------------------#
#   Origin_VOCdevkit_path   原始标签所在的路径
#   Out_VOCdevkit_path      输出标签所在的路径
#                                   处理后的标签为灰度图,如果设置的值太小会看不见具体情况。
#-----------------------------------------------------------------------------------#
Origin_VOCdevkit_path   = "VOCdevkit_Origin"
Out_VOCdevkit_path      = "VOCdevkit"
#-----------------------------------------------------------------------------------#
#   Out_Num                 生成多少组图片
#   input_shape             生成的图片大小
#-----------------------------------------------------------------------------------#
Out_Num                 = 30
input_shape             = [640, 640]

#-----------------------------------------------------------------------------------#
#   下面定义了xml里面的组成模块,无需改动。
#-----------------------------------------------------------------------------------#
headstr = """\

    VOC
    %s
    
        My Database
        COCO
        flickr
        NULL
    
    
        NULL
        company
    
    
        %d
        %d
        %d
    
    0
"""

objstr = """\
    
        %s
        Unspecified
        0
        0
        
            %d
            %d
            %d
            %d
        
    
"""
    
tailstr = '''\

'''
if __name__ == "__main__":
    Origin_JPEGImages_path  = os.path.join(Origin_VOCdevkit_path, "VOC2007/JPEGImages")
    Origin_Annotations_path = os.path.join(Origin_VOCdevkit_path, "VOC2007/Annotations")
    
    Out_JPEGImages_path  = os.path.join(Out_VOCdevkit_path, "VOC2007/JPEGImages")
    Out_Annotations_path = os.path.join(Out_VOCdevkit_path, "VOC2007/Annotations")
    
    if not os.path.exists(Out_JPEGImages_path):
        os.makedirs(Out_JPEGImages_path)
    if not os.path.exists(Out_Annotations_path):
        os.makedirs(Out_Annotations_path)
    #---------------------------#
    #   遍历标签并赋值
    #---------------------------#
    xml_names = os.listdir(Origin_Annotations_path)

    def write_xml(anno_path, jpg_pth, head, input_shape, boxes, unique_labels, tail):
        f = open(anno_path, "w")
        f.write(head%(jpg_pth, input_shape[0], input_shape[1], 3))
        for i, box in enumerate(boxes):
            f.write(objstr%(str(unique_labels[int(box[4])]), box[0], box[1], box[2], box[3]))
        f.write(tail)
    #########以上部分无需改动###########################################################################################
    #------------------------------#
    #   循环生成xml和jpg
    #------------------------------#
    for index in range(Out_Num):
        #------------------------------#
        #   获取一个图像与标签
        #------------------------------#
        sample_xmls     = sample(xml_names, 1)
        unique_labels   = get_classes(sample_xmls, Origin_Annotations_path)
        
        jpg_name  = os.path.join(Origin_JPEGImages_path, os.path.splitext(sample_xmls[0])[0] + '.jpg')
        xml_name  = os.path.join(Origin_Annotations_path, sample_xmls[0])
            
        line = convert_annotation(jpg_name, xml_name, unique_labels)
        
        #------------------------------#
        #   各自数据增强
        #------------------------------#
        image_data, box_data  = get_random_data(line, input_shape) 
        
        img = Image.fromarray(image_data.astype(np.uint8))
        img.save(os.path.join(Out_JPEGImages_path, str(index) + '.jpg'))
        write_xml(os.path.join(Out_Annotations_path, str(index) + '.xml'), os.path.join(Out_JPEGImages_path, str(index) + '.jpg'), \
                    headstr, input_shape, box_data, unique_labels, tailstr)

你可能感兴趣的:(目标检测,目标检测,学习,人工智能)