# YOLOv5 training-image augmentation – Python code (YoloV5训练图片增强 python代码)
# Generate More Labels – 生成配套的Label (labels matching the augmented images)
from xml.etree.ElementTree import ElementTree, Element
import xml.etree.ElementTree as ET
import os
from tqdm import tqdm
def change_box(update_path, save_path, filename, box_n, new_value):
    """Overwrite one bounding-box coordinate of the first object in a label.

    Args:
        update_path: Directory that holds the source annotation XML.
        save_path: Directory the modified annotation is written to.
        filename: Name of the XML file; the output keeps the same name.
        box_n: Tag of the ``bndbox`` child to change (e.g. ``"xmin"``).
        new_value: New coordinate. It is converted with ``str()`` because
            ElementTree can only serialise text, so ints are accepted too.

    Raises:
        AttributeError: If the file has no ``object``/``bndbox`` element or
            no child named ``box_n``.
    """
    tree = ET.parse(os.path.join(update_path, filename))
    bndbox = tree.getroot().find('object').find('bndbox')
    bndbox.find(box_n).text = str(new_value)
    tree.write(os.path.join(save_path, filename), encoding='utf-8')
def change_name(update_path, save_path, filename, new_name):
    """Copy a label file under a new name and record that name inside it.

    The ``<filename>`` element of the annotation is set to ``new_name`` and
    the tree is written to ``save_path`` as ``new_name``. Bounding boxes and
    size information are left untouched.

    Args:
        update_path: Directory that holds the source annotation XML.
        save_path: Directory the renamed annotation is written to.
        filename: Name of the source XML file.
        new_name: Name of the output file, also stored in ``<filename>``.
    """
    tree = ET.parse(os.path.join(update_path, filename))
    root = tree.getroot()
    name_node = root.find("filename")
    if name_node is None:
        # Some annotations omit <filename>; add it instead of crashing.
        name_node = ET.SubElement(root, "filename")
    name_node.text = new_name
    tree.write(os.path.join(save_path, new_name), encoding='utf-8')
def rotate_90(update_path, save_path, filename, new_name):
    """Write the label for an image rotated 90 degrees counter-clockwise.

    A point ``(x, y)`` of a ``width`` x ``height`` image lands on
    ``(y, width - x)`` after the rotation, so every box becomes
    ``(ymin, width - xmax, ymax, width - xmin)``. Every ``object`` in the
    file is transformed, and the recorded ``width``/``height`` are swapped
    (a no-op for the square images this pipeline produces).

    Args:
        update_path: Directory that holds the source annotation XML.
        save_path: Directory the new annotation is written to.
        filename: Name of the source XML file.
        new_name: Name of the annotation file to create.
    """
    tree = ET.parse(os.path.join(update_path, filename))
    root = tree.getroot()
    size = root.find('size')
    width_node, height_node = size.find('width'), size.find('height')
    width, height = int(width_node.text), int(height_node.text)

    for obj in root.findall('object'):
        bndbox = obj.find('bndbox')
        xmin, ymin, xmax, ymax = (
            bndbox.find(tag) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        x0, y0, x1, y1 = (int(node.text) for node in (xmin, ymin, xmax, ymax))
        xmin.text = str(y0)
        ymin.text = str(width - x1)
        xmax.text = str(y1)
        # The new vertical extent is measured against the old width.
        ymax.text = str(width - x0)

    # A quarter turn exchanges the image dimensions.
    width_node.text, height_node.text = str(height), str(width)
    tree.write(os.path.join(save_path, new_name), encoding='utf-8')
def rotate_180(update_path, save_path, filename, new_name):
    """Write the label for an image rotated by 180 degrees.

    A point ``(x, y)`` maps to ``(width - x, height - y)``, so every box
    becomes ``(width - xmax, height - ymax, width - xmin, height - ymin)``.
    Horizontal coordinates are mirrored against the width and vertical ones
    against the height. Every ``object`` in the file is transformed.

    Args:
        update_path: Directory that holds the source annotation XML.
        save_path: Directory the new annotation is written to.
        filename: Name of the source XML file.
        new_name: Name of the annotation file to create.
    """
    tree = ET.parse(os.path.join(update_path, filename))
    root = tree.getroot()
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)

    for obj in root.findall('object'):
        bndbox = obj.find('bndbox')
        xmin, ymin, xmax, ymax = (
            bndbox.find(tag) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        x0, y0, x1, y1 = (int(node.text) for node in (xmin, ymin, xmax, ymax))
        xmin.text = str(width - x1)
        ymin.text = str(height - y1)
        xmax.text = str(width - x0)
        ymax.text = str(height - y0)

    tree.write(os.path.join(save_path, new_name), encoding='utf-8')
def gen_fli_label(update_path, save_path, filename, new_name):
    """Write the label for a horizontally flipped (left-right) image.

    Only x coordinates change: ``x`` maps to ``width - x``, so every box
    becomes ``(width - xmax, ymin, width - xmin, ymax)``. Every ``object``
    in the file is transformed.

    Args:
        update_path: Directory that holds the source annotation XML.
        save_path: Directory the new annotation is written to.
        filename: Name of the source XML file.
        new_name: Name of the annotation file to create.
    """
    tree = ET.parse(os.path.join(update_path, filename))
    root = tree.getroot()
    width = int(root.find('size').find('width').text)

    for obj in root.findall('object'):
        bndbox = obj.find('bndbox')
        xmin, xmax = bndbox.find('xmin'), bndbox.find('xmax')
        x0, x1 = int(xmin.text), int(xmax.text)
        # Mirroring swaps which edge is the left one.
        xmin.text = str(width - x1)
        xmax.text = str(width - x0)

    tree.write(os.path.join(save_path, new_name), encoding='utf-8')
def modify_quality(update_path, save_path, filename, new_name):
    """Create the label for a quality-only variant of an image.

    The work is delegated to ``change_name``: the annotation is copied to
    ``new_name`` with its ``<filename>`` updated, and the bounding boxes are
    left exactly as they were.
    """
    change_name(
        update_path=update_path,
        save_path=save_path,
        filename=filename,
        new_name=new_name,
    )
def gen_label_square(update_path, save_path, filename):
    """Adjust a label for an image that was padded to a square.

    The recorded ``width`` and ``height`` both become the longer side, and
    the boxes are shifted by half the size difference along the shorter
    axis (integer division), i.e. the image is assumed to be centred on the
    square canvas. Every ``object`` in the file is shifted so all boxes stay
    aligned with the padded image.

    Args:
        update_path: Directory that holds the source annotation XML.
        save_path: Directory the adjusted annotation is written to.
        filename: Name of the XML file; the output keeps the same name.
    """
    tree = ET.parse(os.path.join(update_path, filename))
    root = tree.getroot()
    size = root.find('size')
    width_node, height_node = size.find('width'), size.find('height')
    width, height = int(width_node.text), int(height_node.text)

    side = max(width, height)
    width_node.text, height_node.text = str(side), str(side)

    # At most one offset is non-zero: padding is added on the shorter axis.
    shift_x = (side - width) // 2
    shift_y = (side - height) // 2

    for obj in root.findall('object'):
        bndbox = obj.find('bndbox')
        for tag, shift in (('xmin', shift_x), ('xmax', shift_x),
                           ('ymin', shift_y), ('ymax', shift_y)):
            if shift:
                node = bndbox.find(tag)
                node.text = str(int(node.text) + shift)

    tree.write(os.path.join(save_path, filename), encoding='utf-8')
def gen_rotate_label(update_path, save_path, filename):
    """Create the 90-degree, 180-degree and flipped labels for one file.

    The new files are named ``<stem>_r90.xml``, ``<stem>_r180.xml`` and
    ``<stem>_fli.xml``. The stem is taken with ``os.path.splitext`` so that
    only the final extension is removed; names containing extra dots
    (``img.v2.xml``) keep their full stem and do not collide.

    Args:
        update_path: Directory that holds the source annotation XML.
        save_path: Directory the new annotations are written to.
        filename: Name of the source XML file.
    """
    stem = os.path.splitext(filename)[0]
    rotate_90(update_path, save_path, filename, new_name=stem + "_r90.xml")
    rotate_180(update_path, save_path, filename, new_name=stem + "_r180.xml")
    gen_fli_label(update_path, save_path, filename, new_name=stem + "_fli.xml")
def gen_modify_label(update_path, save_path, filename):
    """Create labels for the blur, brighter, darker and noise variants.

    One copy of the annotation is written per variant, named
    ``<stem>_blur.xml``, ``<stem>_brighter.xml``, ``<stem>_darker.xml`` and
    ``<stem>_noise.xml``. The noise suffix is spelled ``_noise`` so the
    label shares its stem with an image named ``<stem>_noise.jpg``.

    Args:
        update_path: Directory that holds the source annotation XML.
        save_path: Directory the new annotations are written to.
        filename: Name of the source XML file.
    """
    # splitext strips only the final extension, so dotted names stay intact.
    stem = os.path.splitext(filename)[0]
    for suffix in ("_blur", "_brighter", "_darker", "_noise"):
        modify_quality(update_path, save_path, filename,
                       new_name=stem + suffix + ".xml")
if __name__ == "__main__":
    update_path = "../data/picture/label/"

    # Stage 1: bring every non-square annotation to a square canvas.
    # Only .xml files are considered so stray files do not break parsing.
    for label_name in tqdm([n for n in os.listdir(update_path)
                            if n.lower().endswith(".xml")]):
        size = ET.parse(update_path + label_name).getroot().find('size')
        if int(size.find('width').text) != int(size.find('height').text):
            gen_label_square(update_path, update_path, label_name)

    # Stage 2: rotated / flipped labels. The folder is listed again on
    # purpose: each stage also processes the files the previous one wrote.
    for label_name in tqdm([n for n in os.listdir(update_path)
                            if n.lower().endswith(".xml")]):
        gen_rotate_label(update_path, update_path, label_name)

    # Stage 3: labels for the blur / brightness / noise variants of
    # everything present so far, including the stage-2 outputs.
    for label_name in tqdm([n for n in os.listdir(update_path)
                            if n.lower().endswith(".xml")]):
        gen_modify_label(update_path, update_path, label_name)
# Generate More Pictures – 生成增强图片 (augmented images)
import cv2
import numpy as np
import os.path
import copy
from tqdm import tqdm
import xml.etree.ElementTree as ET
from PIL import Image
def trans_square(image, save_name):
    """Pad a PIL image to a square, save it and return the padded image.

    The image is converted to RGB and pasted onto a white canvas whose side
    equals the longer image side. The offset along the shorter axis is half
    the size difference (integer division). The result is saved to
    ``save_name`` with ``quality=95`` and the path is printed.
    """
    rgb = image.convert('RGB')
    w, h = rgb.size
    side = max(w, h)
    canvas = Image.new('RGB', size=(side, side), color=(255, 255, 255))

    offset = int(abs(w - h) // 2)
    if w < h:
        corner = (offset, 0)
    else:
        corner = (0, offset)

    canvas.paste(rgb, corner)
    print(save_name)
    canvas.save(save_name, quality=95)
    return canvas
def salt_and_pepper(src, percetage):
    """Return a copy of ``src`` with salt-and-pepper noise.

    ``int(percetage * rows * cols)`` random samples are drawn; each picks a
    random pixel and one of the first three channels and sets it to 0
    (pepper) or 255 (salt) with equal probability. The same position may be
    drawn more than once. The input array is not modified.

    Args:
        src: Image array indexed as ``[row, col, channel]`` with at least
            three channels.
        percetage: Number of noise samples as a fraction of the pixel count.

    Returns:
        A new array with the same shape and dtype as ``src``.
    """
    noisy = src.copy()
    rows, cols = src.shape[0], src.shape[1]
    count = int(percetage * rows * cols)
    if count <= 0:
        return noisy

    # randint's upper bound is exclusive, so pass the full extent to make
    # the last row/column reachable, and use 2 so both outcomes can occur.
    row_idx = np.random.randint(0, rows, size=count)
    col_idx = np.random.randint(0, cols, size=count)
    chan_idx = np.random.randint(0, 3, size=count)
    is_pepper = np.random.randint(0, 2, size=count) == 0

    noisy[row_idx, col_idx, chan_idx] = np.where(is_pepper, 0, 255).astype(noisy.dtype)
    return noisy
def addGaussianNoise(image, percetage, sigma=25.0):
    """Return a copy of ``image`` with Gaussian noise added to random pixels.

    ``int(percetage * rows * cols)`` random samples are drawn; each picks a
    pixel (and, for colour images, one of the first three channels) and adds
    a zero-mean normal value with standard deviation ``sigma`` to it. For
    integer images the result is rounded and clipped to the dtype range so
    values never wrap around. The input array is not modified.

    Args:
        image: Image array, either ``[row, col]`` or ``[row, col, channel]``.
        percetage: Number of noise samples as a fraction of the pixel count.
        sigma: Standard deviation of the added noise, in pixel-value units.

    Returns:
        A new array with the same shape and dtype as ``image``.
    """
    noisy = image.copy()
    rows, cols = image.shape[0], image.shape[1]
    count = int(percetage * rows * cols)
    if count <= 0:
        return noisy

    row_idx = np.random.randint(0, rows, size=count)
    col_idx = np.random.randint(0, cols, size=count)
    if image.ndim == 2:
        index = (row_idx, col_idx)
    else:
        chan_idx = np.random.randint(0, min(3, image.shape[2]), size=count)
        index = (row_idx, col_idx, chan_idx)

    # Perturb the existing value instead of replacing it with the sample.
    perturbed = noisy[index].astype(np.float64) + np.random.randn(count) * sigma
    if np.issubdtype(noisy.dtype, np.integer):
        limits = np.iinfo(noisy.dtype)
        perturbed = np.clip(np.rint(perturbed), limits.min, limits.max)
    noisy[index] = perturbed.astype(noisy.dtype)
    return noisy
def darker(image, percetage=0.9):
    """Return a darkened copy of ``image``.

    The first three channels (or the single plane of a 2-D image) are
    multiplied by ``percetage`` and truncated back to the input dtype. Any
    further channel, such as alpha, is copied unchanged. For integer images
    the product is clipped to the dtype range, so a factor above 1 saturates
    instead of overflowing. The input array is not modified.

    Args:
        image: Image array, either ``[row, col]`` or ``[row, col, channel]``.
        percetage: Brightness factor; values below 1 darken the image.

    Returns:
        A new array with the same shape and dtype as ``image``.
    """
    result = image.copy()
    # A view onto the channels to scale, so the write below lands in result.
    target = result if image.ndim == 2 else result[:, :, :3]

    scaled = target.astype(np.float64) * percetage
    if np.issubdtype(image.dtype, np.integer):
        limits = np.iinfo(image.dtype)
        scaled = np.clip(scaled, limits.min, limits.max)
    target[...] = scaled.astype(image.dtype)
    return result
def brighter(image, percetage=1.5):
    """Return a brightened copy of ``image``.

    The first three channels (or the single plane of a 2-D image) are
    multiplied by ``percetage``, clipped to ``[0, 255]`` and truncated back
    to the input dtype. Any further channel, such as alpha, is copied
    unchanged. The input array is not modified.

    Args:
        image: Image array, either ``[row, col]`` or ``[row, col, channel]``.
        percetage: Brightness factor; values above 1 brighten the image.

    Returns:
        A new array with the same shape and dtype as ``image``.
    """
    result = image.copy()
    # A view onto the channels to scale, so the write below lands in result.
    target = result if image.ndim == 2 else result[:, :, :3]

    scaled = np.clip(target.astype(np.float64) * percetage, 0, 255)
    target[...] = scaled.astype(image.dtype)
    return result
def rotate(image, angle, center=None, scale=1.0):
    """Rotate an image about a point and return the result.

    The rotation matrix comes from ``cv2.getRotationMatrix2D`` and is
    applied with ``cv2.warpAffine``. The output keeps the input width and
    height. When ``center`` is omitted the pivot is ``(w / 2, h / 2)``.
    NOTE(review): OpenCV documents positive angles as counter-clockwise;
    confirm against the label transforms that pair with this function.
    """
    height, width = image.shape[:2]
    pivot = (width / 2, height / 2) if center is None else center
    matrix = cv2.getRotationMatrix2D(pivot, angle, scale)
    return cv2.warpAffine(image, matrix, (width, height))
def flip(image):
    """Return a left-right mirrored copy of ``image``.

    ``np.fliplr`` alone yields a negative-stride view that shares memory
    with the input. The result is materialised as a contiguous array so it
    is independent of ``image``, like the other augmentation helpers, and
    can be handed to code that requires a contiguous buffer.
    """
    return np.ascontiguousarray(np.fliplr(image))
def run(file_dir):
    """Augment every image found in ``file_dir``, writing results in place.

    Three passes are made, and the folder is listed again before each one,
    so later passes also process the files written by earlier ones:

    1. Non-square images are padded to a square and overwritten.
    2. For every file, ``_r90``, ``_r180`` and ``_fli`` variants are written.
    3. For every file, ``_noise``, ``_darker``, ``_brighter`` and ``_blur``
       variants are written.

    Output names are ``<stem><suffix>.jpg``, where the stem is the file name
    without its final extension, so extensions of any length are handled.
    Files that ``cv2.imread`` cannot decode are skipped in passes 2 and 3.

    Args:
        file_dir: Directory containing the images.
    """
    def _variant_path(name, suffix):
        # <file_dir>/<stem><suffix>.jpg, independent of the extension length.
        return os.path.join(file_dir, os.path.splitext(name)[0] + suffix + '.jpg')

    # Pass 1: square every image so later rotations keep the whole picture.
    for img_name in tqdm(os.listdir(file_dir)):
        img_path = os.path.join(file_dir, img_name)
        with Image.open(img_path) as image:
            if image.size[0] != image.size[1]:
                trans_square(image, img_path)

    # Pass 2: geometric variants.
    for img_name in tqdm(os.listdir(file_dir)):
        img = cv2.imread(os.path.join(file_dir, img_name))
        if img is None:
            # imread signals an unreadable file by returning None.
            continue
        cv2.imwrite(_variant_path(img_name, '_r90'), rotate(img, 90))
        cv2.imwrite(_variant_path(img_name, '_r180'), rotate(img, 180))
        cv2.imwrite(_variant_path(img_name, '_fli'), flip(img))

    # Pass 3: photometric variants of everything present so far.
    for img_name in tqdm(os.listdir(file_dir)):
        img = cv2.imread(os.path.join(file_dir, img_name))
        if img is None:
            continue
        cv2.imwrite(_variant_path(img_name, '_noise'), addGaussianNoise(img, 0.3))
        cv2.imwrite(_variant_path(img_name, '_darker'), darker(img))
        cv2.imwrite(_variant_path(img_name, '_brighter'), brighter(img))
        cv2.imwrite(_variant_path(img_name, '_blur'),
                    cv2.GaussianBlur(img, (7, 7), 1.5))
if __name__ == '__main__':
    # Script entry point: augment the images of the dataset folder in place.
    run('../data/picture/img/')