网上有比较多的目标检测数据集增强的方法,但是其数据的格式在txt中是
x_min,y_min,x_max,y_max,class
然后我现在做的任务是检测不规则的四边形,数据标签是
x1,y1,x2,y2,x3,y3,x4,y4,class
所以重新写了一个数据增强的程序,在这里记录一下,下次用起来就方便多了
import numpy as np
import cv2
import math
from PIL import Image, ImageDraw
from skimage.util import random_noise
from skimage import exposure
import random
import os
def rotate_box(img, pts, maxangle, scale=1):
    """Rotate an image and its quadrilateral labels by a random angle.

    The angle is drawn uniformly from [0, maxangle] degrees. The output
    canvas is enlarged to ``|sin|*h + |cos|*w`` by ``|cos|*h + |sin|*w``
    (times ``scale``) and the affine matrix is shifted accordingly.

    Args:
        img: Image array; ``shape[0]`` is the height, ``shape[1]`` the width.
        pts: Iterable of point lists; each point provides x at index 0 and
            y at index 1.
        maxangle: Upper bound of the random rotation angle, in degrees.
        scale: Scale factor passed to ``cv2.getRotationMatrix2D``.

    Returns:
        Tuple ``(rot_img, rot_pts)``: the warped image and a list with one
        array of transformed points per input point list.
    """
    height, width = img.shape[0], img.shape[1]
    angle = random.uniform(0, maxangle)
    theta = np.deg2rad(angle)
    sin_t, cos_t = np.sin(theta), np.cos(theta)

    # Size of the canvas used for the rotated image.
    new_w = (abs(sin_t * height) + abs(cos_t * width)) * scale
    new_h = (abs(cos_t * height) + abs(sin_t * width)) * scale

    # Rotate about the centre of the new canvas, then add the translation
    # that moves the original image centre onto the new canvas centre.
    rot_mat = cv2.getRotationMatrix2D((new_w * 0.5, new_h * 0.5), angle, scale)
    centre_shift = np.array([(new_w - width) * 0.5, (new_h - height) * 0.5, 0])
    offset = np.dot(rot_mat, centre_shift)
    rot_mat[0, 2] += offset[0]
    rot_mat[1, 2] += offset[1]

    out_size = (int(math.ceil(new_w)), int(math.ceil(new_h)))
    rot_img = cv2.warpAffine(img, rot_mat, out_size, flags=cv2.INTER_LANCZOS4)

    # Apply the same affine matrix to every label point (homogeneous form).
    rot_pts = [
        np.array([
            np.dot(rot_mat, np.array([pt[0], pt[1], 1])).tolist()
            for pt in quad
        ])
        for quad in pts
    ]
    return rot_img, rot_pts
def addNoise(img):
    """Add Gaussian noise to an image.

    Calls ``skimage.util.random_noise`` with ``mode='gaussian'`` and
    ``clip=True`` and multiplies the result by 255.

    NOTE(review): random_noise presumably returns a float image in [0, 1]
    for unsigned input, so the result here would be a float array in
    [0, 255] rather than uint8 -- confirm against the skimage docs.
    """
    noisy = random_noise(img, mode='gaussian', clip=True)
    return noisy * 255
def changeLight(img):
    """Change image brightness with a random gamma correction.

    The gamma value is drawn uniformly from [0.5, 1.5] and applied with
    ``skimage.exposure.adjust_gamma``.
    """
    gamma = random.uniform(0.5, 1.5)
    adjusted = exposure.adjust_gamma(img, gamma)
    return adjusted
def shift_pic_bboxes(img, pts):
    """Translate an image and its label points by a random offset.

    The offset range is derived from the distance between the extreme
    label coordinates and the image borders: each direction is limited to
    one third of ``(distance - 1)``.

    Args:
        img: Image array; ``shape[0]`` is the height, ``shape[1]`` the width.
        pts: Iterable of point lists; each point is an array-like with x at
            index 0 and y at index 1.

    Returns:
        Tuple ``(shift_img, shifted_pts)``: the translated image (same size
        as the input) and a list with one array of shifted points per
        input point list.
    """
    height, width = img.shape[0], img.shape[1]

    # Extreme label coordinates, seeded with the image size / zero exactly
    # as a running min/max over all points would be.
    xs = [pt[0] for quad in pts for pt in quad]
    ys = [pt[1] for quad in pts for pt in quad]
    x_min = min([width] + xs)
    y_min = min([height] + ys)
    x_max = max([0] + xs)
    y_max = max([0] + ys)

    # Free space between the labels and each image border.
    left_gap = x_min
    right_gap = width - x_max
    top_gap = y_min
    bottom_gap = height - y_max

    x = random.uniform(-(left_gap - 1) / 3, (right_gap - 1) / 3)
    y = random.uniform(-(top_gap - 1) / 3, (bottom_gap - 1) / 3)

    translation = np.float32([[1, 0, x], [0, 1, y]])
    shift_img = cv2.warpAffine(img, translation, (width, height))

    offset = np.array([x, y])
    shifted_pts = [
        np.array([(pt + offset).tolist() for pt in quad])
        for quad in pts
    ]
    return shift_img, shifted_pts
def flip_img(img, pts):
    """Flip an image and its label points, horizontally or vertically.

    The direction is chosen at random with equal probability. A horizontal
    flip maps x to ``width - x``; a vertical flip maps y to ``height - y``.

    The input ``pts`` is never modified: each point is copied before its
    coordinate is mirrored, so callers can safely reuse the original
    labels for further augmentations.

    Args:
        img: Image array; ``shape[0]`` is the height, ``shape[1]`` the width.
        pts: Iterable of point lists; each point is array-like with x at
            index 0 and y at index 1.

    Returns:
        Tuple ``(flipped, flipped_pts)``: the flipped image and a list with
        one array of mirrored points per input point list.
    """
    horizon = random.random() > 0.5
    # cv2.flip: flipCode 1 mirrors around the vertical axis, 0 around the
    # horizontal axis.
    flipped = cv2.flip(img, 1 if horizon else 0)
    height, width = img.shape[0], img.shape[1]

    flipped_pts = []
    for point_list in pts:
        mirrored = []
        for pt in point_list:
            # Work on a copy so the caller's label arrays stay untouched.
            new_pt = np.array(pt)
            if horizon:
                new_pt[0] = width - new_pt[0]
            else:
                new_pt[1] = height - new_pt[1]
            mirrored.append(new_pt.tolist())
        flipped_pts.append(np.array(mirrored))
    return flipped, flipped_pts
def cut_area(img):
    """Return a copy of the image with a random black square drawn on it.

    The square's side is one tenth (integer division) of the shorter image
    dimension and its top-left corner is chosen at random so that the
    square stays inside the image.

    The drawing is done on a copy: ``cv2.rectangle`` paints in place, and
    painting on the caller's array would leave the occlusion on the source
    image for every later augmentation of the same picture.

    Args:
        img: Image array; ``shape[0]`` is the height, ``shape[1]`` the width.

    Returns:
        A new image array with the filled black square.
    """
    size = min(img.shape[0], img.shape[1]) // 10
    x = random.randint(0, img.shape[1] - size - 1)
    y = random.randint(0, img.shape[0] - size - 1)
    occluded = img.copy()
    # Thickness -1 asks OpenCV for a filled rectangle.
    cv2.rectangle(occluded, (x, y), (x + size, y + size), (0, 0, 0), -1)
    return occluded
# Source label/image folders and destination folders for the augmented set.
txt_path = '/home/user/AdvancedEAST/train_card/origin_txt/'
img_path = '/home/user/AdvancedEAST/train_card/origin_image/'
save_txt = '/home/user/AdvancedEAST/train_card_aug/origin_txt/'
save_img = '/home/user/AdvancedEAST/train_card_aug/origin_img/'
aug_number = 5  # number of augmented samples generated per source image
max_angle = 90  # upper bound (degrees) handed to rotate_box

# Make sure the destination folders exist before anything is written.
os.makedirs(save_txt, exist_ok=True)
os.makedirs(save_img, exist_ok=True)

img_list = os.listdir(img_path)
for j, file_name in enumerate(img_list, start=1):
    # splitext copes with extensions of any length (.jpg, .jpeg, ...).
    stem = os.path.splitext(file_name)[0]
    img_dir = img_path + file_name
    txt_dir = txt_path + stem + '.txt'

    origin_img = cv2.imread(img_dir)
    if origin_img is None:
        # cv2.imread returns None for files it cannot decode.
        print('skip unreadable image: {}'.format(img_dir))
        continue

    # origin_pts holds one (4, 2) array per label row: x1,y1,...,x4,y4.
    origin_pts = []
    with open(txt_dir) as label_file:
        for line in label_file:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines in the label file
            fields = line.split(',')
            # The first eight fields are the corner coordinates; whatever
            # follows (the class field) is not used.
            coords = np.array(fields[:8]).astype(float)
            origin_pts.append(np.reshape(coords, (4, 2)))

    for i in range(aug_number):
        # Start every round from fresh copies so that in-place edits made
        # by an augmentation never leak into the originals.
        img = origin_img.copy()
        pts = [quad.copy() for quad in origin_pts]

        # Each transform is applied with probability 0.5; repeat the draw
        # until at least one transform has been applied.
        changenum = 0
        while changenum < 1:
            if random.random() > 0.5:
                img = changeLight(img)
                changenum += 1
            if random.random() > 0.5:
                img = addNoise(img)
                changenum += 1
            if random.random() > 0.5:
                img = cut_area(img)
                changenum += 1
            if random.random() > 0.5:
                # max_angle is the largest rotation angle allowed.
                img, pts = rotate_box(img, pts, max_angle)
                changenum += 1
            if random.random() > 0.5:
                img, pts = shift_pic_bboxes(img, pts)
                changenum += 1
            if random.random() > 0.5:
                img, pts = flip_img(img, pts)
                changenum += 1

        cv2.imwrite(save_img + '{}_{}.jpg'.format(stem, i), img)

        # One output row per quadrilateral: x1,y1,...,x4,y4,unknown
        # NOTE(review): the last field is always the literal 'unknown';
        # the class field of the input row is not carried over -- confirm
        # this is intended.
        rows = [
            ','.join(
                [str(value) for pt in quad for value in (pt[0], pt[1])]
                + ['unknown']
            ) + '\n'
            for quad in pts
        ]
        with open(save_txt + '{}_{}.txt'.format(stem, i), 'w') as f:
            f.writelines(rows)

    print('{}/{}'.format(j, len(img_list)))