简单的数据增广(Data Augmentation)(附代码)

简单的数据增广(Data Augmentation)

博主最近做一个小样本的项目,当时时间紧迫,就找了几何变换的数据增广方法。数据增广的原理就不说了,Some Improvements on Deep Convolutional Neural Network Based Image Classification这篇文章里有对我用的方法和作用有简单的介绍。
废话少说,现在开始说重点吧。

方法

本文为为数据使用了翻转、裁剪和添加噪声的操作。

工具

本文中使用的是VOC格式数据的数据增广,因为COCO格式的annotation解析真的太麻烦了,如果你是COCO格式的数据。。。那可能这篇文章对你没啥帮助了,但是如果你是VOC格式的数据,或者你要用coco格式但是也有VOC格式的,可以先对VOC格式数据进行数据增广,然后用以下链接的方法转为COCO格式: VOC2COCO
本次使用到的工具有如下:

# python3
# cv2
# xml
# numpy

代码

在使用的时候直接修改代码前面的地址参数就可以了。
代码如下:

import cv2 as cv
import numpy as np
import os
import xml.etree.ElementTree as ET
import copy
import random

data_path = 'E:/Working/CornerNet-master/data/coco/images/cancer_image'
save_img_path = 'test_imgs'             # 'data/coco/images'
save_anno_path = 'test_annos'            # 'data/coco/annotations'
mask_path = 'new_Annotations'
suffix = '.jpg'


def get_anno(xml_path):
    tree = ET.parse(xml_path)
    root = tree.getroot()

    if root.tag != 'annotation':
        raise Exception("root should be annotation")

    object_ = []
    img_shape = None
    for elem in root:
        tag = elem.tag
        if tag == 'size':
            img_shape = [int(elem[0].text), int(elem[1].text)]
        if tag == 'object':
            anno_shape = [int(elem[2][0].text), int(elem[2][1].text),
                          int(elem[2][2].text), int(elem[2][3].text)]
            object_.append(anno_shape)
    return img_shape, object_


def flip(img, i_shape, anno_shape):
    '''
    :param img:
    :param i_shape: 列表
    :param d_shape: 包含各个anno列表的list
    :return:
    '''
    d_shape = copy.deepcopy(anno_shape)

    xImg = cv.flip(img, 1, dst=None)
    ret_ishape = i_shape
    arr_long = len(d_shape)
    # print(i_shape,d_shape,arr_long)
    for i in range(arr_long):
        # xmin. ymin, xmax, ymax
        # 注意该标注中以宽为X,高为Y
        # print(i_shape[1], d_shape[i][2])
        temp = d_shape[i][0]
        d_shape[i][0] = i_shape[1] - d_shape[i][2]
        d_shape[i][2] = i_shape[1] - temp
        # print(d_shape[i][0])

    return xImg, ret_ishape, d_shape


def crop(img, i_shape, annos_shape):
    crop_max_long = 0.4
    d_shape = copy.deepcopy(annos_shape)
    # 判断对一边裁剪或者两边裁剪:0左上,1右下
    x_crop_loc = random.randint(0, 2)
    y_crop_loc = random.randint(0, 2)

    if x_crop_loc is 0:
        rand = random.random()
        x_min = 0
        x_max = i_shape[1]*(1-crop_max_long*rand)
    elif x_crop_loc is 1:
        rand = random.random()
        x_min = i_shape[1]*(crop_max_long*rand)
        x_max = i_shape[1]
    else:
        rand1 = random.random()
        rand2 = random.random()
        x_min = i_shape[1]*(crop_max_long*rand1/2)
        x_max = i_shape[1] * (1 - crop_max_long * rand2 / 2)

    if y_crop_loc is 0:
        rand = random.random()
        y_min = 0
        y_max = i_shape[0] * (1 - crop_max_long * rand)
    elif x_crop_loc is 1:
        rand = random.random()
        y_min = i_shape[0] * (crop_max_long * rand)
        y_max = i_shape[0]
    else:
        rand1 = random.random()
        rand2 = random.random()
        y_min = i_shape[0] * (crop_max_long * rand1 / 2)
        y_max = i_shape[0] * (1 - crop_max_long * rand2 / 2)

    x_max = int(x_max)
    y_max = int(y_max)
    x_min = int(x_min)
    y_min = int(y_min)
    new_img_shape = [y_max-y_min, x_max-x_min]
    a_l = len(d_shape)
    for i in range(a_l):
        if x_crop_loc == 1 or x_crop_loc == 2:
            d_shape[i][0] -= x_min                  # xmin
            d_shape[i][2] -= x_min                  # xmax
            if d_shape[i][0] < 0:
                d_shape[i][0] = 0
            elif d_shape[i][0] > new_img_shape[1]:
                d_shape[i][0] = new_img_shape[1] - 1                # 此处直接认为mask溢出则把边界认为是mask是错误的
            if d_shape[i][2] > new_img_shape[1]:                    # 但是推测对图片的影响不大
                d_shape[i][2] = new_img_shape[1]                    # 如果修改了这里需要对Data_Augmentation修改
            elif d_shape[i][2] < 0:                                 # 所以就先试一试 !#
                d_shape[i][2] = 1
        if y_crop_loc == 1 or y_crop_loc == 2:
            d_shape[i][1] -= y_min                  # ymin
            d_shape[i][3] -= y_min                  # ymax
            if d_shape[i][1] < 0:
                d_shape[i][1] = 0
            elif d_shape[i][1] > new_img_shape[0]:
                d_shape[i][1] = new_img_shape[0] - 1
            if d_shape[i][3] > new_img_shape[0]:
                d_shape[i][3] = new_img_shape[0]
            elif d_shape[i][3] < 0:
                d_shape[i][3] = 1
    return copy.deepcopy(img[y_min:y_max, x_min:x_max, :]), new_img_shape, d_shape


def noise(img, i_shape, d_shape):
    # 添加正态分布的高斯噪声
    image = np.array(img, dtype=float)
    noise = np.random.normal(size=(i_shape[0], i_shape[1]))

    # print(i_shape)
    out = np.zeros_like(image)
    for i in range(3):
        out[:, :, i] = image[:, :, i]+noise
    out[out > 255] = 255
    out[out < 0] = 0
    out = out.astype('uint8')
    return out, i_shape, d_shape


# def save(name, img, img_shape, annos_shape):
#     i_name = name+suffix
#     l_name = name+'.xml'
#     save_i_name = os.path.join(save_img_path, i_name)
#     save_a_name = os.path.join(save_anno_path, l_name)

def augmentation(img_path, xml_path):
    imgs = []
    img_shape, annos_shape = get_anno(xml_path)
    img = cv.imread(img_path)
    imgs.append((img, img_shape, annos_shape))
    # flip
    img1, img_shape1, annos_shape1 = flip(img, img_shape, annos_shape)
    imgs.append((img1, img_shape1, annos_shape1))
    # crop
    temp = []
    for i in imgs:
        i_i, i_i_s, i_a_s = i[0], i[1], i[2]
        t_img, t_i_s, t_a_s = crop(i_i, i_i_s, i_a_s)
        temp.append((t_img, t_i_s, t_a_s))
    imgs.extend(temp)
    # noise
    temp = []
    for i in imgs:
        i_i, i_i_s, i_a_s = i[0], i[1], i[2]
        t_img, t_i_s, t_a_s = noise(i_i, i_i_s, i_a_s)
        temp.append((t_img, t_i_s, t_a_s))
    imgs.extend(temp)
    return imgs


if __name__ == "__main__":
    img_dirs = os.listdir(data_path)
    xml_dirs = os.listdir(mask_path)
    
    i = 0                   # i作为保存文件的名字
    for xml_dir in xml_dirs:
        print(i)
        img_dir = os.path.join(data_path, xml_dir.split('.')[0] + suffix)
        xml_dir = os.path.join(mask_path, xml_dir)
        imgs = augmentation(img_dir, xml_dir)
        # print(imgs[0][0])
        for j in range(8):
            image_data = imgs[j]
            name = str(i)
            i += 1
            img_path = os.path.join(save_img_path, name + suffix)
            # print(img_path)
            cv.imwrite(img_path, image_data[0])
            tree = ET.parse(xml_dir)
            root = tree.getroot()
            object_num = 0
            for elem in root:
                tag = elem.tag
                if tag == 'size':
                    elem[0].text, elem[1].text = str(image_data[1][0]), str(image_data[1][1])
                if tag == 'filename':
                    elem.text = name + '.xml'
                if tag == 'object':
                    elem[2][0].text, elem[2][1].text, elem[2][2].text, elem[2][3].text = \
                        str(image_data[2][object_num][0]), str(image_data[2][object_num][1]), \
                        str(image_data[2][object_num][2]), str(image_data[2][object_num][3])
                    object_num += 1
            write_path = os.path.join(save_anno_path, name + '.xml')
            tree.write(write_path)


你可能感兴趣的:(简单的数据增广(Data Augmentation)(附代码))