目标检测任务中对四点坐标类的数据标签与图片做增强

网上有比较多的目标检测数据集增强的方法,但是其数据的格式在txt中是

x_min,y_min,x_max,y_max,class

然后我现在做的任务是检测不规则的四边形,数据标签是

x1,y1,x2,y2,x3,y3,x4,y4,class

所以重新写了一个增强的程序,这里记录一下,下次用起来就方便多了

import numpy as np
import cv2
import math
from PIL import Image, ImageDraw
from skimage.util import random_noise
from skimage import exposure
import random
import os
def rotate_box(img,pts,maxangle,scale=1):
    """Rotate an image by a random angle and move the label points with it.

    The angle is drawn uniformly from ``[0, maxangle]`` and handed to
    ``cv2.getRotationMatrix2D`` (which takes degrees). The output canvas is
    sized to the bounding box of the rotated ``w x h`` rectangle, and the
    affine matrix is shifted so the source image lands in the middle of it.

    Args:
        img: Image array; ``shape[0]`` is used as height, ``shape[1]`` as width.
        pts: Iterable of point groups; every point is indexed as
            ``pt[0]`` (x) and ``pt[1]`` (y).
        maxangle: Upper bound of the random rotation angle, in degrees.
        scale: Scale factor applied together with the rotation.

    Returns:
        ``(rot_img, rotated_pts)`` where ``rotated_pts`` is a list holding one
        NumPy array of transformed ``[x, y]`` rows per input point group.
    """
    height, width = img.shape[0], img.shape[1]
    angle = random.uniform(0, maxangle)
    theta = np.deg2rad(angle)

    # Bounding size of the rotated (and scaled) image.
    new_w = (abs(np.sin(theta)*height) + abs(np.cos(theta)*width))*scale
    new_h = (abs(np.cos(theta)*height) + abs(np.sin(theta)*width))*scale

    # Rotate about the centre of the new canvas, then add the translation
    # that moves the old image centre onto the new canvas centre.
    rot_mat = cv2.getRotationMatrix2D((new_w*0.5, new_h*0.5), angle, scale)
    centre_shift = np.array([(new_w-width)*0.5, (new_h-height)*0.5, 0])
    rot_mat[:, 2] += np.dot(rot_mat, centre_shift)

    canvas_size = (int(math.ceil(new_w)), int(math.ceil(new_h)))
    rot_img = cv2.warpAffine(img, rot_mat, canvas_size, flags=cv2.INTER_LANCZOS4)

    # Apply the same affine matrix to every label point (homogeneous coords).
    rotated_pts = [
        np.array([
            np.dot(rot_mat, np.array([pt[0], pt[1], 1])).tolist()
            for pt in point_group
        ])
        for point_group in pts
    ]
    return rot_img, rotated_pts


def addNoise(img):
    """Add Gaussian noise to an image.

    Calls ``skimage.util.random_noise`` in ``'gaussian'`` mode with clipping
    enabled and multiplies the result by 255.
    NOTE(review): presumably the result is a float image in the 0-255 range
    rather than uint8 - confirm against the skimage version in use.
    """
    noisy = random_noise(img, mode='gaussian', clip=True)
    return noisy * 255

def changeLight(img):
    """Apply a random gamma adjustment to an image.

    The gamma value is drawn uniformly from ``[0.5, 1.5]`` and passed to
    ``skimage.exposure.adjust_gamma``; the adjusted image is returned.
    """
    gamma = random.uniform(0.5, 1.5)
    adjusted = exposure.adjust_gamma(img, gamma)
    return adjusted

def shift_pic_bboxes(img,pts):
    """Translate an image by a random offset and move the label points with it.

    The offset range is derived from the free margin between the extent of
    all label points and the image border: each side contributes
    ``(margin - 1) / 3`` as the limit in its direction.

    Args:
        img: Image array; ``shape[0]`` is used as height, ``shape[1]`` as width.
        pts: Iterable of point groups; every point is indexed as
            ``pt[0]`` (x) and ``pt[1]`` (y) and must support ``pt + ndarray``.

    Returns:
        ``(shift_img, shifted_pts)`` where ``shift_img`` has the same size as
        ``img`` and ``shifted_pts`` is a list with one NumPy array of moved
        ``[x, y]`` rows per input point group.
    """
    height, width = img.shape[0], img.shape[1]

    # Extent of every label point; seeded with the image size / zero so the
    # result is defined even when there are no points.
    corners = [pt for point_group in pts for pt in point_group]
    x_min = min([width] + [pt[0] for pt in corners])
    y_min = min([height] + [pt[1] for pt in corners])
    x_max = max([0] + [pt[0] for pt in corners])
    y_max = max([0] + [pt[1] for pt in corners])

    # Free margin on each side of that extent.
    margin_left, margin_right = x_min, width - x_max
    margin_top, margin_bottom = y_min, height - y_max

    # Draw dx first, then dy.
    dx = random.uniform(-(margin_left-1) / 3, (margin_right-1) / 3)
    dy = random.uniform(-(margin_top-1) / 3, (margin_bottom-1) / 3)

    translation = np.float32([[1, 0, dx], [0, 1, dy]])
    shift_img = cv2.warpAffine(img, translation, (width, height))

    offset = np.array([dx, dy])
    shifted_pts = [
        np.array([(pt + offset).tolist() for pt in point_group])
        for point_group in pts
    ]
    return shift_img, shifted_pts
def flip_img(img,pts):
    """Flip an image horizontally or vertically (chosen at random) with its labels.

    With probability 0.5 the image is mirrored left-right, otherwise
    top-bottom. Label coordinates are mirrored the same way
    (``x -> width - x`` or ``y -> height - y``).

    The input ``pts`` is never modified: each point group is copied before
    its coordinates are rewritten. Writing into the caller's arrays would
    silently flip the original labels and desynchronise them from the
    original image on the next augmentation round.

    Args:
        img: Image array; ``shape[0]`` is used as height, ``shape[1]`` as width.
        pts: Iterable of point groups, each convertible to an ``(N, 2)`` array
            of ``[x, y]`` rows.

    Returns:
        ``(flipped_img, flipped_pts)`` where ``flipped_pts`` is a list with one
        new NumPy array per input point group.
    """
    horizon = random.random() > 0.5
    height, width = img.shape[0], img.shape[1]

    # Reverse the column (horizontal) or row (vertical) axis. The result is
    # materialised as a contiguous copy instead of a negative-stride view.
    if horizon:
        flipped = np.ascontiguousarray(img[:, ::-1])
    else:
        flipped = np.ascontiguousarray(img[::-1])

    flipped_pts = []
    for point_group in pts:
        group = np.array(point_group)  # independent copy of the caller's data
        if group.size:
            if horizon:
                group[:, 0] = width - group[:, 0]
            else:
                group[:, 1] = height - group[:, 1]
        flipped_pts.append(group)
    return flipped, flipped_pts

def cut_area(img):
    """Return a copy of the image with one random square blacked out.

    ``size`` is one tenth (integer division) of the shorter image side. The
    top-left corner ``(x, y)`` is drawn so the square stays inside the image,
    and the filled area spans ``x..x+size`` and ``y..y+size`` inclusive, i.e.
    ``(size + 1) x (size + 1)`` pixels.

    The input array is left untouched: the square is painted on a copy, so a
    caller that keeps a reference to the original image does not accumulate
    black patches across repeated calls.

    Args:
        img: Image array; ``shape[0]`` is used as height, ``shape[1]`` as width.

    Returns:
        A new array of the same shape and dtype with the square set to 0.
    """
    size = min(img.shape[0], img.shape[1]) // 10

    # Draw x first, then y.
    x = random.randint(0, img.shape[1] - size - 1)
    y = random.randint(0, img.shape[0] - size - 1)

    erased = img.copy()
    # Both corners are inclusive, hence the +1 on the slice ends.
    erased[y:y + size + 1, x:x + size + 1] = 0
    return erased

# ---------------------------------------------------------------------------
# Augmentation driver.
#
# For every image in ``img_path`` the matching label file in ``txt_path``
# (same base name, ``.txt`` extension) is read. Each label line is
# ``x1,y1,x2,y2,x3,y3,x4,y4,class``. ``aug_number`` augmented copies are
# written to ``save_img`` / ``save_txt`` as ``<name>_<k>.jpg`` / ``.txt``.
# ---------------------------------------------------------------------------
txt_path='/home/user/AdvancedEAST/train_card/origin_txt/'
img_path='/home/user/AdvancedEAST/train_card/origin_image/'
save_txt='/home/user/AdvancedEAST/train_card_aug/origin_txt/'
save_img='/home/user/AdvancedEAST/train_card_aug/origin_img/'
aug_number=5   # augmented copies generated per source image
max_angle=90   # upper bound of the random rotation, in degrees

# Make sure the output folders exist before anything is written to them.
os.makedirs(save_txt, exist_ok=True)
os.makedirs(save_img, exist_ok=True)

img_list=os.listdir(img_path)
for j, file_name in enumerate(img_list, 1):
    # splitext copes with extensions of any length (".jpg", ".jpeg", ...).
    base_name = os.path.splitext(file_name)[0]
    img_dir = os.path.join(img_path, file_name)
    txt_dir = os.path.join(txt_path, base_name + '.txt')

    origin_img = cv2.imread(img_dir)
    if origin_img is None:
        # cv2.imread returned None, so there is no image to augment.
        print('skip unreadable image: {}'.format(file_name))
        continue

    # origin_pts: one (4, 2) float array per non-blank label line; the last
    # column (class) is dropped.
    with open(txt_dir) as label_file:
        origin_pts = [
            np.reshape(np.array(line.strip().split(','))[:-1].astype(float), (4, 2))
            for line in label_file
            if line.strip()
        ]

    for i in range(aug_number):
        # Start every round from fresh copies so that no augmentation step
        # can alter the originals used by the following rounds.
        img = origin_img.copy()
        pts = [quad.copy() for quad in origin_pts]

        # Each step fires with probability 0.5; repeat the pass until at
        # least one step has been applied.
        changenum = 0
        while changenum < 1:
            if random.random() > 0.5:
                img = changeLight(img)
                changenum += 1

            if random.random() > 0.5:
                img = addNoise(img)
                changenum += 1

            if random.random() > 0.5:
                img = cut_area(img)
                changenum += 1

            if random.random() > 0.5:
                # max_angle is the largest rotation angle that may be drawn.
                img, pts = rotate_box(img, pts, max_angle)
                changenum += 1

            if random.random() > 0.5:
                img, pts = shift_pic_bboxes(img, pts)
                changenum += 1

            if random.random() > 0.5:
                img, pts = flip_img(img, pts)
                changenum += 1

        out_name = '{}_{}'.format(base_name, i)
        if not cv2.imwrite(os.path.join(save_img, out_name + '.jpg'), img):
            print('failed to write image: {}.jpg'.format(out_name))

        # One output line per quadrilateral. The class column of the source
        # label is not carried over; every row is tagged 'unknown'.
        label_lines = [
            ''.join(str(pt[0]) + ',' + str(pt[1]) + ',' for pt in point_list)
            + 'unknown\n'
            for point_list in pts
        ]
        with open(os.path.join(save_txt, out_name + '.txt'), 'w') as f:
            f.writelines(label_lines)
    print('{}/{}'.format(j,len(img_list)))

 

你可能感兴趣的:(python,深度学习)