yolo数据增强以及批量修改图片和xml名

记录下打完标签对数据集进行扩增,数据增强后的图片及标签名字进行修改重点在代码只需更改文件名就可使用

无论数据增强还是修改名称,标签框位置都会跟着改变!!!

前人之鉴,最好还是数据增强后再去打标签,千万千万千万不要图省事

一、数据增强

我使用的是Albumentations扩增数据集,增强方法请移步该博客数据增强方法

如果不知道增强效果,可参考官网给的模拟器模拟增强

import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2

import albumentations as A
import xml.etree.ElementTree as ET

# 定义类
class VOCAug(object):

    def __init__(self,
                 pre_image_path=None,
                 pre_xml_path=None,
                 aug_image_save_path=None,
                 aug_xml_save_path=None,
                 start_aug_id=None,
                 labels=None,
                 max_len=4,     # 修改数值可以改变名字 1-1, 2-01, 3-001, 4-0001
                 is_show=False):
        """

        :param pre_image_path:
        :param pre_xml_path:
        :param aug_image_save_path:
        :param aug_xml_save_path:
        :param start_aug_id:
        :param labels: 标签列表, 展示增强后的图片用
        :param max_len:
        :param is_show:
        """
        self.pre_image_path = pre_image_path
        self.pre_xml_path = pre_xml_path
        self.aug_image_save_path = aug_image_save_path
        self.aug_xml_save_path = aug_xml_save_path
        self.start_aug_id = start_aug_id
        self.labels = labels
        self.max_len = max_len
        self.is_show = is_show

        print(self.labels)
        assert self.labels is not None, "labels is None!!!"

        # 数据增强选项
        # 数据增强选项
        self.aug = A.Compose([
            # A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5), # 随机亮度对比度
            # A.RandomBrightness(limit=0.3, p=0.5),
            # A.GaussianBlur(p=0.7), # 高斯模糊
            # A.GaussNoise(var_limit=(400, 450),mean=0,p=1),  # 高斯噪声
            # A.CLAHE(clip_limit=2.0, tile_grid_size=(4, 4), p=0.5),  # 直方图均衡
            # A.Equalize(p=0.5),  # 均衡图像直方图
            #  A.Rotate(limit=90, interpolation=0, border_mode=0, p=1),    # 旋转
            A.RandomRotate90(p=1),

            # A.CoarseDropout(p=0.5),  # 随机生成矩阵黑框
            # A.OneOf([
            #     # A.RGBShift(r_shift_limit=50, g_shift_limit=50, b_shift_limit=50, p=0.5), #RGB图像的每个通道随机移动值
            #     # A.ChannelShuffle(p=0.3),  # 随机排列通道
            #     # A.ColorJitter(p=0.3),  # 随机改变图像的亮度、对比度、饱和度、色调
            #     # A.ChannelDropout(p=0.3),  # 随机丢弃通道
            # ], p=0.5),
            # A.Downscale(p=0.1),  # 随机缩小和放大来降低图像质量
            # A.Emboss(p=0.2),  # 压印输入图像并将结果与原始图像叠加
        ],
            # voc: [xmin, ymin, xmax, ymax]  # 经过归一化
            # min_area: 表示bbox占据的像素总个数, 当数据增强后, 若bbox小于这个值则从返回的bbox列表删除该bbox.
            # min_visibility: 值域为[0,1], 如果增强后的bbox面积和增强前的bbox面积比值小于该值, 则删除该bbox
            A.BboxParams(format='pascal_voc', min_area=0., min_visibility=0., label_fields=['category_id'])
        )
        print('--------------*--------------')
        print("labels: ", self.labels)
        if self.start_aug_id is None:
            self.start_aug_id = len(os.listdir(self.pre_xml_path)) +1
            print("the start_aug_id is not set, default: len(images)", self.start_aug_id)
        print('--------------*--------------')

    def get_xml_data(self, xml_filename):
        with open(os.path.join(self.pre_xml_path, xml_filename), 'r') as f:
            tree = ET.parse(f)
            root = tree.getroot()
            image_name = tree.find('filename').text
            size = root.find('size')
            w = int(size.find('width').text)
            h = int(size.find('height').text)
            bboxes = []
            cls_id_list = []
            for obj in root.iter('object'):
                # difficult = obj.find('difficult').text
                difficult = obj.find('difficult').text
                cls_name = obj.find('name').text  # label
                if cls_name not in LABELS or int(difficult) == 1:
                    continue
                xml_box = obj.find('bndbox')

                xmin = int(xml_box.find('xmin').text)
                ymin = int(xml_box.find('ymin').text)
                xmax = int(xml_box.find('xmax').text)
                ymax = int(xml_box.find('ymax').text)

                # 标注越界修正
                if xmax > w:
                    xmax = w
                if ymax > h:
                    ymax = h
                bbox = [xmin, ymin, xmax, ymax]
                bboxes.append(bbox)
                cls_id_list.append(self.labels.index(cls_name))

            # 读取图片
            image = cv2.imread(os.path.join(self.pre_image_path, image_name))

        return bboxes, cls_id_list, image, image_name

    def aug_image(self):
        xml_list = os.listdir(self.pre_xml_path)

        cnt = self.start_aug_id
        for xml in xml_list:
            file_suffix = xml.split('.')[-1]
            if file_suffix not in ['xml']:
                continue

            bboxes, cls_id_list, image, image_name = self.get_xml_data(xml)

            anno_dict = {'image': image, 'bboxes': bboxes, 'category_id': cls_id_list}
            # 获得增强后的数据 {"image", "bboxes", "category_id"}
            augmented = self.aug(**anno_dict)

            # 保存增强后的数据
            flag = self.save_aug_data(augmented, image_name, cnt)

            if flag:
                cnt += 1
            else:
                continue

    def save_aug_data(self, augmented, image_name, cnt):
        aug_image = augmented['image']
        aug_bboxes = augmented['bboxes']
        aug_category_id = augmented['category_id']
        # print(aug_bboxes)
        # print(aug_category_id)

        name = '0' * self.max_len
        # 获取图片的后缀名
        image_suffix = image_name.split(".")[-1]

        # 未增强对应的xml文件名
        pre_xml_name = image_name.replace(image_suffix, 'xml')

        # 获取新的增强图像的文件名
        cnt_str = str(cnt)
        length = len(cnt_str)
        new_image_name = name[:-length] + cnt_str + "." + image_suffix

        # 获取新的增强xml文本的文件名
        new_xml_name = new_image_name.replace(image_suffix, 'xml')

        # 获取增强后的图片新的宽和高
        new_image_height, new_image_width = aug_image.shape[:2]

        # 深拷贝图片
        aug_image_copy = aug_image.copy()

        # 在对应的原始xml上进行修改, 获得增强后的xml文本
        with open(os.path.join(self.pre_xml_path, pre_xml_name), 'r') as pre_xml:
            aug_tree = ET.parse(pre_xml)

        # 修改image_filename值
        root = aug_tree.getroot()
        aug_tree.find('filename').text = new_image_name

        # 修改变换后的图片大小
        size = root.find('size')
        size.find('width').text = str(new_image_width)
        size.find('height').text = str(new_image_height)

        # 修改每一个标注框
        for index, obj in enumerate(root.iter('object')):
            obj.find('name').text = self.labels[aug_category_id[index]]
            xmin, ymin, xmax, ymax = aug_bboxes[index]
            xml_box = obj.find('bndbox')
            xml_box.find('xmin').text = str(int(xmin))
            xml_box.find('ymin').text = str(int(ymin))
            xml_box.find('xmax').text = str(int(xmax))
            xml_box.find('ymax').text = str(int(ymax))
            if self.is_show:
                tl = 2
                text = f"{LABELS[aug_category_id[index]]}"
                t_size = cv2.getTextSize(text, 0, fontScale=tl / 3, thickness=tl)[0]
                cv2.rectangle(aug_image, (int(xmin), int(ymin) - 3),
                              (int(xmin) + t_size[0], int(ymin) - t_size[1] - 3),
                              (0, 0, 255), -1, cv2.LINE_AA)  # filled
                cv2.putText(aug_image, text, (int(xmin), int(ymin) - 2), 0, tl / 3, (255, 255, 255), tl,
                            cv2.LINE_AA)
                cv2.rectangle(aug_image, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 255, 0), 2)

        if self.is_show:
            cv2.imshow('aug_image_show', aug_image_copy)
            # 按下s键保存增强,否则取消保存此次增强
            key = cv2.waitKey(0)
            if key & 0xff == ord('s'):
                pass
            else:
                return False
        # 保存增强后的图片
        cv2.imwrite(os.path.join(self.aug_image_save_path, new_image_name), aug_image)
        # 保存增强后的xml文件
        tree = ET.ElementTree(root)
        tree.write(os.path.join(self.aug_xml_save_path, new_xml_name))

        return True

# 原始的xml路径和图片路径
PRE_IMAGE_PATH = r'E:\ML-data\VOC\images'
PRE_XML_PATH = r'E:\ML-data\VOC\labels'

# 增强后保存的xml路径和图片路径
AUG_SAVE_IMAGE_PATH ='E:\ML-data\VOC\images-aug'
AUG_SAVE_XML_PATH = 'E:\ML-data\VOC\labels-aug'

# 标签列表
LABELS = ['buds']

aug = VOCAug(
    pre_image_path=PRE_IMAGE_PATH,
    pre_xml_path=PRE_XML_PATH,
    aug_image_save_path=AUG_SAVE_IMAGE_PATH,
    aug_xml_save_path=AUG_SAVE_XML_PATH,
    start_aug_id=None,
    labels=LABELS,
    is_show=False,
)

aug.aug_image()

# cv2.destroyAllWindows()

只需要在数据增强选项中添加想要的增强方法,其余的并不建议修改

注:请重点关注下max_len参数,数值的大小决定增强后的图片和标签名字

 二、批量修改图片和xml名称

import os
import shutil


def rename_file(path, new_path, xml_path, new_xml):
    # 打开源文件图像
    file = os.listdir(path)
    for i in range(len(file)):
        # 获得图像扩展名
        (name, extent) = os.path.splitext(file[i])
        # 获得图像对应的xml文件
        xml_file = os.path.join(xml_path, name + '.xml')

        # 源文件
        src = os.path.join(path, str(file[i]))
        a = 1 + i
        # 对应的xml文件复制到新的路径中,并制定新的名称
        # shutil.copy(xml_file, os.path.join(new_xml, '0' + str(a) + '.xml')) #如果想新生成的名字不是1,2,3,4... 而是01,02,03..  只需要把该行代码解除注释
        shutil.copy(xml_file, os.path.join(new_xml, str(a) + '.xml'))
        # 新图像对应的名字及路径
        # new = os.path.join(new_path, '0' + str(a) + '.jpg')
        new = os.path.join(new_path,  str(a) + '.jpg')
        # 文件重命名,且源文件夹中无文件
        os.rename(src, new)


if __name__ == "__main__":
    path = r'E:\ML-data\VOCdevkit(new)\VOC2007\img'
    new_path = r"E:\ML-data\VOCdevkit(new)\VOC2007\JPEGImages"
    xml_path = r'E:\ML-data\VOCdevkit(new)\VOC2007\xml'
    new_xml = r'E:\ML-data\VOCdevkit(new)\VOC2007\Annotations'
    rename_file(path, new_path, xml_path, new_xml)

只需要更改为自己的文件名

注:该代码生成图片新名称时,相当于把原来文件夹的图片移到新的文件夹中,原来的文件夹的图片就没了,记得要备份一下再运行代码(xml原文件夹仍然存在,不需要备份)

2023.5.25 新增一个数据集划分代码,权当记录一下

具体是干啥用的,请看代码行注释

仔细一想,还得加上xml转txt的代码,正好配套

这是xml转txt 的代码(修改对于的路径就可),依次运行下面两个代码就会得到如图的文件夹

yolo数据增强以及批量修改图片和xml名_第1张图片

 

# -*- coding: utf-8 -*-
import xml.etree.ElementTree as ET
import os
from os import getcwd

sets = ['train', 'val', 'test']
classes = ["buds"]  # 改成自己的类别
abs_path = os.getcwd()
print(abs_path)


def convert(size, box):
    dw = 1. / (size[0])
    dh = 1. / (size[1])
    x = (box[0] + box[1]) / 2.0 - 1
    y = (box[2] + box[3]) / 2.0 - 1
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return x, y, w, h


def convert_annotation(image_id):
    in_file = open(r'E:\Deep learning\yoloair-iscyy-beta\VOCdata\Annotations\%s.xml' % (image_id), encoding='UTF-8')
    out_file = open(r'E:\Deep learning\yoloair-iscyy-beta\VOCData\labels\%s.txt' % (image_id), 'w')
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        # difficult = obj.find('Difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
             float(xmlbox.find('ymax').text))
        b1, b2, b3, b4 = b
        # 标注越界修正
        if b2 > w:
            b2 = w
        if b4 > h:
            b4 = h
        b = (b1, b2, b3, b4)
        bb = convert((w, h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')


wd = getcwd()
for image_set in sets:
    if not os.path.exists('E:\Deep learning\yoloair-iscyy-beta\VOCData/labels/'):
        os.makedirs('E:\Deep learning\yoloair-iscyy-beta/VOCData/labels/')
    image_ids = open('E:\Deep learning\yoloair-iscyy-beta/VOCData/ImageSets/Main/%s.txt' % (image_set)).read().strip().split()

    if not os.path.exists('E:\Deep learning\yoloair-iscyy-beta/VOCData/dataSet_path/'):
        os.makedirs('E:\Deep learning\yoloair-iscyy-beta/VOCData/dataSet_path/')

    list_file = open('dataSet_path/%s.txt' % (image_set), 'w')
    # 这行路径不需更改,这是相对路径
    for image_id in image_ids:
        list_file.write('E:\Deep learning\yoloair-iscyy-beta/VOCData/images/%s.jpg\n' % (image_id))
        convert_annotation(image_id)
    list_file.close()
# coding:utf-8

import os
import random
import argparse

parser = argparse.ArgumentParser()
#xml文件的地址,根据自己的数据进行修改 xml一般存放在Annotations下
parser.add_argument('--xml_path', default='Annotations', type=str, help='input xml label path')
#数据集的划分,地址选择自己数据下的ImageSets/Main
parser.add_argument('--txt_path', default='ImageSets/Main', type=str, help='output txt label path')
opt = parser.parse_args()

trainval_percent = 1.0  # 训练集和验证集所占比例。 这里没有划分测试集
train_percent = 0.9     # 训练集所占比例,可自己进行调整
xmlfilepath = opt.xml_path
txtsavepath = opt.txt_path
total_xml = os.listdir(xmlfilepath)
if not os.path.exists(txtsavepath):
    os.makedirs(txtsavepath)

num = len(total_xml)
list_index = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(list_index, tv)
train = random.sample(trainval, tr)

file_trainval = open(txtsavepath + '/trainval.txt', 'w')
file_test = open(txtsavepath + '/test.txt', 'w')
file_train = open(txtsavepath + '/train.txt', 'w')
file_val = open(txtsavepath + '/val.txt', 'w')

for i in list_index:
    name = total_xml[i][:-4] + '\n'
    if i in trainval:
        file_trainval.write(name)
        if i in train:
            file_train.write(name)
        else:
            file_val.write(name)
    else:
        file_test.write(name)

file_trainval.close()
file_train.close()
file_val.close()
file_test.close()

这是把数据集和标签分到train和val两个文件夹中的代码

# 在使用v5或者v7时,数据集的格式大部分都是如下所示
# VOCdevkit
#     |---- Annotations (存放的xml)
#     |---- JEPGImages (存放的数据集图片)
# 
# 但像v6,你就必须把你的数据集弄成如下格式
# mydata
#    |----- images
#             |---- train
#             |---- val
#    |----- labels
#             |---- train
#             |---- val
# 所以需要用的以下代码来把图片和标签放在上边的文件夹里

import os
import shutil

# 读入分类的标签txt文件
label_file = open(r"E:\OtherModles\YOLOv6-main\VOCdata\ImageSets\Main\train.txt", 'r')  
# 原始文件的根目录  JPEGImages 划分图片/ labels 划分标签
input_path = "E:\OtherModles\YOLOv6-main\VOCdata\labels"
# 保存文件的根目录
output_path = "E:\OtherModles\YOLOv6-main\My_DATA/labels/train"

# 一行行读入标签文件
data = label_file.readlines()
# 计数用
i = 1
# 遍历数据
for line in data:
    # 通过空格拆分成数组
    str1 = line.split(" ")
    # 第一个是文件名
    file_name = str1[0].strip()
    # 原始文件的路径
    old_file_path = os.path.join(input_path, file_name + ".jpg")
    # old_file_path = os.path.join(input_path, file_name + ".txt")  # 当你需要把label放在train/val两个文件夹时,解除注释
    # 新文件路径
    new_file_path = output_path

    # 如果路径不存在,则创建
    if not os.path.exists(new_file_path):
        print("路径 " + new_file_path + " 不存在,正在创建......")
        os.makedirs(new_file_path)

    # 新文件位置
    new_file_path = os.path.join(new_file_path, file_name + ".png")
    # new_file_path = os.path.join(new_file_path, file_name + ".txt")     # 当你需要把label放在train/val两个文件夹时,解除注释
    print("" + str(i) + "\t正在将 " + old_file_path + " 复制到 " + new_file_path)
    # 复制文件
    shutil.copyfile(old_file_path, new_file_path)

    i = i + 1
# 完成提示
print("完成")

你可能感兴趣的:(YOLO,xml)