目标检测中的填鸭式数据增强方法(利用 VOC 数据格式)

目标检测中的填鸭式数据增强方法(利用 VOC 数据格式)

注意:这里的数据增强代码基于 VOC 数据格式编写。

填鸭式数据增强对小目标检测应该效果不错:比如一张图片里目标较少且尺寸较小时,网络可能无法有效关注到这些目标。填鸭式数据增强就是把这些目标复制粘贴到图片中的多个位置,增加它们在图片中出现的数量,从而加强训练。

import random
import numpy as np
import xml.dom.minidom
import matplotlib.pyplot as plt
from PIL import Image,ImageDraw
from os import getcwd

def bbox_iou(box1, box2):
    """Return the intersection-over-union of two boxes.

    Each box is an ``(x1, y1, x2, y2)`` 4-sequence whose corner coordinates are
    both treated as inclusive pixels, which is why every width and height has
    ``+ 1`` added. When the boxes do not overlap the integer ``0`` is returned.
    """
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2

    # Size of the overlap rectangle (inclusive-pixel convention).
    overlap_w = min(ax2, bx2) - max(ax1, bx1) + 1
    overlap_h = min(ay2, by2) - max(ay1, by1) + 1

    # Both sides must be strictly positive for the boxes to actually overlap.
    if not (overlap_w > 0 and overlap_h > 0):
        return 0

    overlap = overlap_w * overlap_h
    area_a = (ax2 - ax1 + 1) * (ay2 - ay1 + 1)
    area_b = (bx2 - bx1 + 1) * (by2 - by1 + 1)
    return overlap / (area_a + area_b - overlap)

def _read_annotation_row(label_file, row):
    """Parse line ``row`` of ``label_file`` into ``[([x1, y1, x2, y2], cls_id), ...]``.

    Each line is whitespace-separated: the first field is skipped (the image
    path) and every following field is ``x1,y1,x2,y2,cls`` of integers.
    """
    with open(label_file, 'r') as f:
        fields = f.readlines()[row].split()
    objects = []
    for token in fields[1:]:
        values = token.split(',')
        box = [int(v) for v in values[:4]]
        objects.append((box, int(values[4])))
    return objects


def aug_data_method(wd, root_path, img_name, row, out_file=None, voc_year=None,
                    label_file="train.txt", specific_idxs=(0,), threshold=0.3,
                    sample_num_per_sample=2):
    """Copy-paste ("填鸭式") augmentation for one image.

    Reads the ground-truth boxes of ``img_name`` from line ``row`` of
    ``label_file``. For every object whose class is in ``specific_idxs`` it
    tries ``sample_num_per_sample`` random placements. A placement is kept
    only if its IoU with every box placed so far is ``<= threshold``. The
    augmented image is saved as ``<root_path>/<img_name>_.jpg``. One line
    ``<wd>/VOCdevkit/VOC<year>/JPEGImages/<img_name>_.jpg box,cls box,cls ...``
    (original boxes first, then the pasted copies) is written to ``out_file``.

    Args:
        wd: project root written as the prefix of the output image path.
        root_path: directory holding ``<img_name>.jpg``.
        img_name: image id without extension.
        row: index of this image's line in ``label_file``. Assumes the rows
            are in the same order as the image ids; TODO confirm for your
            data.
        out_file: writable text file for the annotation line. Defaults to the
            module-level ``list_file`` set in the ``__main__`` block.
        voc_year: VOC year used in the written path. Defaults to the
            module-level ``year``.
        label_file: annotation txt in the format described in 注1.
        specific_idxs: class ids whose objects are duplicated.
        threshold: maximum IoU a pasted copy may have with any existing box.
        sample_num_per_sample: placement attempts per selected object.
    """
    # Fall back to the globals the original script relied on.
    if out_file is None:
        out_file = list_file
    if voc_year is None:
        voc_year = year

    objects = _read_annotation_row(label_file, row)  # see 注1 for the format
    bboxes = [box for box, _ in objects]

    with Image.open(root_path + '/' + img_name + ".jpg") as img:
        width, height = img.size
        # Crop from an untouched copy so earlier pastes never leak into later crops.
        source = img.copy()

        out_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s_.jpg' % (wd, voc_year, img_name))
        # Original boxes keep their own class ids.
        for box, cls_id in objects:
            out_file.write(" " + ",".join(str(a) for a in box) + ',' + str(cls_id))

        for (left, top, right, bottom), cls_id in objects:
            if cls_id not in specific_idxs:
                continue
            box_w, box_h = right - left, bottom - top
            if box_w > width or box_h > height:
                continue  # the copy could not fit inside the image
            region = source.crop((left, top, right, bottom))
            for _ in range(sample_num_per_sample):
                new_left = random.randint(0, width - box_w)
                new_top = random.randint(0, height - box_h)
                new_box = [new_left, new_top, new_left + box_w, new_top + box_h]
                if all(bbox_iou(new_box, b) <= threshold for b in bboxes):
                    bboxes.append(new_box)
                    # The pasted copy has the same class as the object it was cut from.
                    out_file.write(" " + ",".join(str(a) for a in new_box) + ',' + str(cls_id))
                    img.paste(region, (new_left, new_top))
        img.save(root_path + '/' + img_name + "_.jpg")
    out_file.write('\n')
    print(row)
    return
if __name__ == '__main__':
    # Each entry is (VOC year, image-set name) to augment.
    sets = [('2012', 'train')]
    for year, image_set in sets:
        # ``year``, ``image_id`` and ``list_file`` deliberately remain module-level
        # names: ``aug_data_method`` may read them from module scope.
        with open('./VOCdevkit/VOC%s/ImageSets/Main/%s.txt' % (year, image_set)) as ids_file:
            image_ids = ids_file.read().strip().split()
        # Images live in the same VOCdevkit tree as the id list, for the current year.
        root_path = './VOCdevkit/VOC%s/JPEGImages' % year
        with open('./%s_augdata.txt' % image_set, 'w') as list_file:
            for image_index, image_id in enumerate(image_ids):
                aug_data_method(wd="/your/project/path", root_path=root_path,
                                img_name=image_id, row=image_index)

注1:此处读取的 train.txt(原文此处附有示例截图)就是由 VOC 格式数据集转换得到、供网络读取的 txt 文件。文件每行格式为 `图片路径 x1,y1,x2,y2,类别 x1,y1,x2,y2,类别 ...`,用于获取真实框和类别。
使用前需要自行修改路径、类别等细节(此处我只有一个类别);此代码仅提供大致思路。

你可能感兴趣的:(数据增强,深度学习)