由于本人在使用yolov3模型跑训练时发现自己标注的图像样本数量实在太少,又懒得再去标注图像,因为标注图像真的是一件超级考验眼睛的事情::>_<:: 故写下了这个数据扩充脚本,将已经标注好的图片进行随机变换,并将该图片对应的txt文件利用线性代数的方法做相应改变,这样标注一张图片就相当于有了好几张图片啦~
代码中有关图像旋转原理详解的部分参考了下面链接中的内容:
https://blog.csdn.net/liyuan02/article/details/6750828
首先直接贴出源码
# -*- coding:utf-8 -*-
import numpy as np
import os
import cv2
import math
import random
start_path = './'
out_path = './out/'
#将所有图片存入列表
def get_filelist():
imgs_list = []
imgs_file = os.listdir(start_path)
for f in imgs_file:
if f.endswith('.jpg'):
imgs_list.append(f)
return imgs_list
#水平镜像翻转
def img_her_flip(pic, labels):
labels = labels
#img_width, img_height, c = pic.shape
flip_pic = cv2.flip(pic, 1, dst=None) # 水平镜像
liner_trans = np.array([[-1, 0, 1],
[0, 1, 0],
[0, 0, 1]])
for i, label in enumerate(labels):
cla, x, y, width, height = label
cen_row_vec = np.array([[float(x), float(y), 1]])
cen_vec = cen_row_vec.T
new_vec = np.dot(liner_trans, cen_vec)
new_x = round(float(new_vec[0][0]), 6)
new_y = round(float(new_vec[1][0]), 6)
labels[i] = [cla, new_x, new_y, width, height]
return flip_pic, labels
#垂直镜像翻转
def img_ver_flip(pic, labels):
labels = labels
flip_pic = cv2.flip(pic,-1,dst=None) # 垂直镜像
liner_trans = np.array([[-1, 0, 1],
[0, -1, 1],
[0, 0, 1]])
for i, label in enumerate(labels):
cla, x, y, width, height = label
cen_row_vec = np.array([[float(x), float(y), 1]])
cen_vec = cen_row_vec.T
new_vec = np.dot(liner_trans, cen_vec)
new_x = round(float(new_vec[0][0]), 6)
new_y = round(float(new_vec[1][0]), 6)
labels[i] = [cla, new_x, new_y, width, height]
return flip_pic, labels
#模糊
def img_blur(pic, labels):
labels = labels
blur_pic = cv2.blur(pic, (5, 5))
return blur_pic, labels
#添加噪音
def add_noise(pic, labels):
noise_pic = pic
labels = labels
for i in range(1000):
temp_x = np.random.randint(0, noise_pic.shape[0])
temp_y = np.random.randint(0, noise_pic.shape[1])
noise_pic[temp_x][temp_y] = 255
return noise_pic, labels
#HSV色彩变换
def img_to_hsv(pic, labels):
labels = labels
hsv_pic = cv2.cvtColor(pic, cv2.COLOR_BGR2HSV)
return hsv_pic, labels
#图像旋转 若beta > 0 表示逆时针旋转,< 0则表示顺时针旋转
def rotate(pic, labels, beta=-90):
rotate_pic = pic
labels = labels
(h, w) = rotate_pic.shape[:2]
(cX, cY) = (w // 2, h // 2)
M = cv2.getRotationMatrix2D((cX, cY), beta, 1.0)
cos = np.abs(M[0, 0])
sin = np.abs(M[0, 1])
nW = int((h * sin) + (w * cos))
nH = int((h * cos) + (w * sin))
M[0, 2] += (nW / 2) - cX
M[1, 2] += (nH / 2) - cY
rotate_pic = cv2.warpAffine(rotate_pic, M, (nW, nH))
liner_trans = np.array([[math.cos(math.radians(beta)), math.sin(math.radians(beta)), -0.5*math.cos(math.radians(beta))-0.5*math.sin(math.radians(beta))+0.5],
[-math.sin(math.radians(beta)), math.cos(math.radians(beta)), 0.5*math.sin(math.radians(beta))-0.5*math.cos(math.radians(beta))+0.5],
[0, 0, 1]])
for i, label in enumerate(labels):
cla, x, y, width, height = label
cen_row_vec = np.array([[float(x), float(y), 1]])
cen_vec = cen_row_vec.T
new_cen = np.dot(liner_trans, cen_vec)
new_x = round(float(new_cen[0][0]), 6)
new_y = round(float(new_cen[1][0]), 6)
if beta == -90 or beta == 270 or beta == 90 or beta == -270:
new_h = width
new_w = height
else:
new_w = width
new_h = height
labels[i] = [cla, new_x, new_y, new_w, new_h]
return rotate_pic, labels
#图像缩放 若factor > 1表示放大, < 1表示缩小
def zoom(pic, labels, factor=1):
labels = labels
img_width, img_height, c = pic.shape
zoom_pic = cv2.resize(pic, (factor*img_width, factor*img_height))
liner_trans = np.array([[factor, 0, 0],
[0, factor, 0],
[0, 0, 1]])
for i, label in enumerate(labels):
cla, x, y, width, height = label
cen_row_vec = np.array([[float(x), float(y), 1]])
cen_vec = cen_row_vec.T
new_vec = np.dot(liner_trans, cen_vec)
new_x = round(float(new_vec[0][0]), 6)
new_y = round(float(new_vec[1][0]), 6)
labels[i] = [cla, new_x, new_y, factor*width, factor*height]
return zoom_pic, labels
#对图片进行随机处理
def img_random_handle(pic, labels, frequency):
fre = frequency
func1_list = [rotate, add_noise, zoom, img_to_hsv, img_ver_flip, img_her_flip, img_blur]
if fre == 1:
pic, labels = random.choice(func1_list)(pic, labels)
elif fre == 2:
func1_choice = random.sample(func1_list, 2)
pic, labels = func1_choice[0](pic, labels)
pic, labels = func1_choice[1](pic, labels)
elif fre == 3:
func1_choice = random.sample(func1_list, 3)
pic, labels = func1_choice[0](pic, labels)
pic, labels = func1_choice[1](pic, labels)
pic, labels = func1_choice[2](pic, labels)
elif fre == 4:
func1_choice = random.sample(func1_list, 4)
pic, labels = func1_choice[0](pic, labels)
pic, labels = func1_choice[1](pic, labels)
pic, labels = func1_choice[2](pic, labels)
pic, labels = func1_choice[3](pic, labels)
else:
print("变换次数透支~")
return pic, labels
#对图片进行变换并保存
def imgs_expand(frequency):
imgs_list= get_filelist()
change_count = 0
for img in imgs_list:
img_path = start_path + img
pic = cv2.imread(img_path)
img_name = img.split('.')[0]
txt_path = start_path + img_name + '.txt'
with open(txt_path, 'r') as txt_file:
labels = []
for label in txt_file.readlines():
c, x, y, w, h = label.rstrip('\n').split(' ')
labels.append([int(c), float(x), float(y), float(w), float(h)])
txt_file.close()
img, sub_labels = img_random_handle(pic, labels, frequency)
img_save_path = out_path + 'img' + str(change_count) + '.jpg'
cv2.imwrite(img_save_path, img)
print(img_save_path)
txt_save_path = out_path + 'img' + str(change_count) + '.txt'
save_labels = []
for label in sub_labels:
c, x, y, w, h = label
split = ' '
save_labels.append(split.join([str(c), str(x), str(y), str(w), str(h)]) + '\n')
with open(txt_save_path, 'w+') as txt:
txt.writelines(save_labels)
txt.close()
print(txt_save_path)
change_count += 1
print("共修改了%d张图片" % change_count)
print("共修改了%d个文件" % change_count)
if __name__ == '__main__':
imgs_expand(2) #frequency为需要对图片做多少次变换的次数 本脚本中设置最大为4,可以在img_random_handle()函数中修改
在这个脚本中,笔者将每个功能都封装为独立的函数,在每个函数中都对该图片和该图片对应的txt文件做对应的修改,这里的txt文件修改主要采用矩阵相乘计算坐标的方法,这么做的好处是不论图片是什么形状的都可以直接借助线性变换的矩阵来计算变换后的坐标(这里有涉及到线性代数的一些知识,不过不用担心,只是最基础的矩阵运算~),计算完成后返回对应的图片和相应的txt文件内容并直接保存。
主函数需要一个参数即对一张图片做多少种变换,接着读取图片的信息并存入数组,调用对图片随机处理的函数,该函数实现将所有功能函数存入数组,随机调用某一个功能函数,并返回对应的图片和该图片对应的txt中的信息,每次处理一张图片,更改对应的txt文件并保存。
若是后续还需要添加新的功能函数,则只需要先实现独立的功能函数,然后在随机处理函数中的函数列表中添加对应的函数即可。
若有错误之处烦请大家指正~