YOLO数据集自动生成脚本

YOLO数据集自动生成脚本

将labelImg标注生成的xml文件转换为YOLO所需的数据集。
脚本可以自动创建YOLO的目录,并且移动文件、生成YOLO格式数据。

"""
labelImg标注的图片自动生成YOLO需要的数据集、生成训练集、验证集、测试集。
"""

import os
import math
import shutil
import xmltodict
from loguru import logger


class YoloData:
    """
    Convert XML annotations produced by labelImg into a YOLO dataset.

    The class is used as a namespace: every method is a static or class
    method, and the class names discovered while converting are accumulated
    on the class itself (they are not reset between calls to ``main``).
    """
    # Class names in order of first appearance; the list index is the YOLO class id.
    classes = []
    # Not read or written by any method of this class.
    name_map = {}

    @staticmethod
    def load_xml_file(filename: str) -> dict:
        """
        Parse an annotation XML file.

        :param filename: path of the XML file (read as UTF-8)
        :return: the content of the root ``annotation`` element as a dict
        """
        with open(filename, mode="r", encoding='utf8') as f:
            xml = f.read()
        return xmltodict.parse(xml)['annotation']

    @classmethod
    def to_yolo_data(cls, annotation: dict) -> str:
        """
        Build the YOLO label text for one annotated image.

        Each object becomes one line ``<class_id> <x_center> <y_center> <w> <h>``
        with the four box values divided by the image width/height and
        formatted with six decimals. Unknown class names are appended to
        ``cls.classes``; the class id is the index of the name in that list.

        :param annotation: dict with ``size`` (``width``/``height``) and an
            optional ``object`` entry holding one object dict or a list of them,
            each with ``name`` and ``bndbox`` (``xmin``/``ymin``/``xmax``/``ymax``)
        :return: label lines joined by ``\\n``; empty string when there is no object
        """
        size = annotation['size']
        width, height = float(size["width"]), float(size["height"])

        # A single child element is parsed as a dict rather than a one-item
        # list, and an annotation without objects has no 'object' key at all.
        objects = annotation.get('object') or []
        if isinstance(objects, dict):
            objects = [objects]

        lines = []
        for obj in objects:
            name = obj['name']
            # Register the name first so the id is the index of *this* class,
            # also when the class has been seen before.
            if name not in cls.classes:
                cls.classes.append(name)
            class_id = cls.classes.index(name)

            box = obj['bndbox']
            x1, y1, x2, y2 = (float(box['xmin']), float(box['ymin']), float(box["xmax"]), float(box["ymax"]))
            x = f"{(x1 + x2) / (2 * width):.6f}"
            y = f"{(y1 + y2) / (2 * height):.6f}"
            w = f"{(x2 - x1) / width:.6f}"
            h = f"{(y2 - y1) / height:.6f}"
            lines.append(' '.join([str(class_id), x, y, w, h]))
        return "\n".join(lines)

    @classmethod
    def save_classes(cls, path: str):
        """
        Write the collected class names to a file, one name per line.

        :param path: destination file (overwritten, UTF-8)
        """
        with open(path, mode='w', encoding='utf8') as f:
            f.write("\n".join(cls.classes))

    @staticmethod
    def _index_by_stem(directory: str, want_xml: bool) -> dict:
        """
        Map file stem -> full path for the entries of a directory.

        :param directory: directory to list
        :param want_xml: True keeps only ``.xml`` entries, False keeps only
            the other entries, so that images and annotations stored in the
            same directory do not overwrite each other
        :return: dict of stem (name without its last extension) to path
        """
        mapping = {}
        for entry in os.listdir(directory):
            stem, ext = os.path.splitext(entry)
            if (ext.lower() == ".xml") == want_xml:
                mapping[stem] = os.path.join(directory, entry)
        return mapping

    @classmethod
    def main(cls, img_path: str, xml_path: str, folder: str = "coco", per="7:2:1"):
        """
        Pair images with their XML files and generate the dataset directory.

        Images and XML files are matched by file name without extension.
        Every pair is copied/converted to ``<folder>/images/<split>/<i>.<ext>``
        and ``<folder>/label/<split>/<i>.txt``; the image path is appended to
        ``<folder>/label/<split>_list.txt`` and the class names are written to
        ``<folder>/classes.txt``.

        :param img_path: directory holding all annotated images
        :param xml_path: directory holding the XML files of those images
        :param folder: output directory. WARNING: if it already exists it is
            removed recursively first, so do not keep important data in it
        :param per: ``train:val:test`` ratio, colon separated numbers
        :return: None
        """
        # Start from a clean output directory.
        if os.path.exists(folder):
            shutil.rmtree(folder)

        # Create the directory layout.
        for sub in ("label/train", "label/val", "label/test", "images/train", "images/val", "images/test"):
            os.makedirs(os.path.abspath(os.path.join(folder, sub)), exist_ok=True)

        # Associate every image with the XML file sharing its stem.
        img_map = cls._index_by_stem(img_path, want_xml=False)
        xml_map = cls._index_by_stem(xml_path, want_xml=True)
        path_map = {img_map[name]: xml_map[name] for name in img_map if name in xml_map}

        # Number of samples per split. Both counts are rounded up, so they are
        # clamped to keep train + val within the number of available pairs.
        pers = [float(i) for i in per.split(":")]
        total = len(path_map)
        train = min(math.ceil(total * pers[0] / sum(pers)), total)
        val = min(math.ceil(total * pers[1] / sum(pers)), total - train)
        logger.debug(f"{train}, {val}, {train + val},{len(path_map)}")

        # Generate the label files and copy the images.
        for i, (src_img, src_xml) in enumerate(path_map.items()):
            logger.debug(f"正在转换第【{i}】个xml, {src_xml}")
            annotation = cls.load_xml_file(src_xml)
            text = cls.to_yolo_data(annotation)
            if i < train:
                mid_folder = "train"
            elif i < train + val:
                mid_folder = "val"
            else:
                mid_folder = "test"

            # Destination paths; files are renamed to their running index.
            label_path = os.path.abspath(os.path.join(folder, "label", mid_folder, f"{i}.txt"))
            img_ext = os.path.splitext(src_img)[1]  # includes the leading dot
            img_dst_path = os.path.abspath(os.path.join(folder, "images", mid_folder, f"{i}{img_ext}"))
            logger.debug(img_dst_path)

            # Write the outputs.
            shutil.copy(src_img, img_dst_path)
            with open(label_path, mode='w', encoding='utf8') as f:
                f.write(text)

            # Record the image in the list file of its split, using the last
            # four components of the destination path.
            listed_path = os.path.join(*img_dst_path.split(os.path.sep)[-4:])
            img_txt_path = os.path.join(folder, "label", f"{mid_folder}_list.txt")
            with open(img_txt_path, mode='a', encoding='utf8') as f:
                f.write(f"{listed_path}\n")

        # Save the classes file.
        classes_path = os.path.abspath(os.path.join(folder, "classes.txt"))
        cls.save_classes(classes_path)
        logger.debug(f"训练集【{train}】张,验证集【{val}】张,测试集【{len(path_map) - train - val}】张,共有【{len(cls.classes)}】个分类")


if __name__ == '__main__':
    # Script entry point: convert the dataset at the hard-coded local paths
    # below into the "coco" directory with an 8 : 1.5 : 0.5 split ratio.
    YoloData.main(
        img_path=r"D:\dataset\ctrip\JPEGImages",
        xml_path=r"D:\dataset\ctrip\Annotations",
        folder="coco",
        per="8:1.5:0.5",
    )

.

你可能感兴趣的:(python,YOLO,python,深度学习)