Python目标检测数据集格式处理,VOC格式转YOLO格式

众所周知,CV算法模型训练第一步该做的是数据集制作,最近遇到需要将VOC格式的数据集转为yolo格式,数据集前期的一些预处理参考博客:Python删除txt文档的某一列_fengfeng18k的博客-CSDN博客

Python修改txt某列元素值,图片重命名_fengfeng18k的博客-CSDN博客

用Python实现:voc2yolo.py

"""
本脚本有两个功能:
1.根据train.txt和val.txt将voc数据集标注信息(.xml)转为yolo标注格式(.txt),生成dataset文件(train+val)
2.根据json标签文件,生成对应names标签(my_data_label.names)
"""
import os
from tqdm import tqdm
from lxml import etree
import json
import shutil
from os.path import *

# --------------------------全局地址变量--------------------------------#
# 拼接出voc的images目录,xml目录,txt目录
# dir_path = dirname(dirname(abspath(__file__)))
# images_path = os.path.join(dir_path, "ApplePest", "images")
images_path = "C:/Users/10974/Desktop/YH/DATASET/fire-smoke/images/"
xml_path = "C:/Users/10974/Desktop/YH/DATASET/fire-smoke/annotations/"
# xml_path = os.path.join(dir_path, "ApplePest", "Annotations")
# train_txt_path = os.path.join(dir_path, "ApplePest", "ImageSets", "train.txt")
train_txt_path = "C:/Users/10974/Desktop/YH/DATASET/fire-smoke/train.txt"
val_txt_path = "C:/Users/10974/Desktop/YH/DATASET/fire-smoke/val.txt"
# val_txt_path = os.path.join(dir_path, "ApplePest", "ImageSets", "val.txt")
# label标签对应json文件
# label_json_path = os.path.join(dir_path, "apple_pest_classes.json")
label_json_path = "C:/Users/10974/Desktop/YH/DATASET/fire-smoke/labels.json"

# save_file_root = os.path.join(dir_path, "dataset")
save_file_root = "C:/Users/10974/Desktop/YH/DATASET/fire-smoke/dataset"

# 检查文件/文件夹都是否存在
assert os.path.exists(images_path), "images path not exist..."
assert os.path.exists(xml_path), "xml path not exist..."
assert os.path.exists(train_txt_path), "train txt file not exist..."
assert os.path.exists(val_txt_path), "val txt file not exist..."
assert os.path.exists(label_json_path), "label_json_path does not exist..."

if os.path.exists(save_file_root) is False:
    os.makedirs(save_file_root)
# --------------------------全局地址变量--------------------------------#




def parse_xml_to_dict(xml):
    """
    将xml文件解析成字典形式,参考tensorflow的recursive_parse_xml_to_dict
    Args:
        xml: xml tree obtained by parsing XML file contents using lxml.etree

    Returns:
        Python dictionary holding XML contents.
    """

    if len(xml) == 0:  # 遍历到底层,直接返回tag对应的信息
        return {xml.tag: xml.text}

    result = {}
    for child in xml:
        child_result = parse_xml_to_dict(child)  # 递归遍历标签信息
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:
            if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                result[child.tag] = []
            result[child.tag].append(child_result[child.tag])
    return {xml.tag: result}


def translate_info(file_names: list, save_root: str, class_dict: dict, train_val='train'):
    """
    将对应xml文件信息转为yolo中使用的txt文件信息
    :param file_names:
    :param save_root:
    :param class_dict:
    :param train_val:
    :return:
    """
    save_txt_path = os.path.join(save_root, train_val, "labels")
    if os.path.exists(save_txt_path) is False:
        os.makedirs(save_txt_path)
    save_images_path = os.path.join(save_root, train_val, "images")
    if os.path.exists(save_images_path) is False:
        os.makedirs(save_images_path)

    for file in tqdm(file_names, desc="translate {} file...".format(train_val)):
        # 检查下图像文件是否存在
        img_path = os.path.join(images_path, file + ".jpg")
        # img_path = os.path.join(images_path, file)
        assert os.path.exists(img_path), "file:{} not exist...".format(img_path)

        # 检查xml文件是否存在
        xml_full_path = os.path.join(xml_path, file + ".xml")
        # xml_full_path = os.path.join(xml_path, file)
        assert os.path.exists(xml_full_path), "file:{} not exist...".format(xml_full_path)

        # read xml
        with open(xml_full_path,encoding='utf-8') as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        # data = parse_xml_to_dict(xml)["annotation"]
        data = parse_xml_to_dict(xml)["annotation"]
        img_height = int(data["size"]["height"])
        img_width = int(data["size"]["width"])

        # write object info into txt
        with open(os.path.join(save_txt_path, file + ".txt"), "w") as f:
            assert "object" in data.keys(), "file: '{}' lack of object key.".format(xml_full_path)
            for index, obj in enumerate(data["object"]):
                # 获取每个object的box信息
                xmin = float(obj["bndbox"]["xmin"])
                xmax = float(obj["bndbox"]["xmax"])
                ymin = float(obj["bndbox"]["ymin"])
                ymax = float(obj["bndbox"]["ymax"])
                class_name = obj["name"]
                # class_index = class_dict[class_name] - 1  # 目标id从0开始
                class_index = class_dict[class_name]  # 目标id从0开始

                # 将box信息转换到yolo格式
                xcenter = xmin + (xmax - xmin) / 2
                ycenter = ymin + (ymax - ymin) / 2
                w = xmax - xmin
                h = ymax - ymin

                # 绝对坐标转相对坐标,保存6位小数
                xcenter = round(xcenter / img_width, 6)
                ycenter = round(ycenter / img_height, 6)
                w = round(w / img_width, 6)
                h = round(h / img_height, 6)

                info = [str(i) for i in [class_index, xcenter, ycenter, w, h]]

                if index == 0:
                    f.write(" ".join(info))
                else:
                    f.write("\n" + " ".join(info))

        # copy image into save_images_path
        # shutil.copyfile(img_path, os.path.join(save_images_path, img_path.split(os.sep)[-1]))
        shutil.copyfile(img_path, os.path.join(save_images_path, img_path.split("/")[-1]))


def create_class_names(class_dict: dict):
    keys = class_dict.keys()
    with open("../dataset_classes.names", "w") as w:
        for index, k in enumerate(keys):
            if index + 1 == len(keys):
                w.write(k)
            else:
                w.write(k + "\n")


def main():
    # read class_indict
    json_file = open(label_json_path, 'r')
    class_dict = json.load(json_file)
    # class_dict = "{fire,smoke}"

    # 读取train.txt中的所有行信息,删除空行
    with open(train_txt_path, "r") as r:
        train_file_names = [i for i in r.read().splitlines() if len(i.strip()) > 0]
    # voc信息转yolo,并将图像文件复制到相应文件夹
    translate_info(train_file_names, save_file_root, class_dict, "train")

    # 读取val.txt中的所有行信息,删除空行
    with open(val_txt_path, "r") as r:
        val_file_names = [i for i in r.read().splitlines() if len(i.strip()) > 0]
    # voc信息转yolo,并将图像文件复制到相应文件夹
    translate_info(val_file_names, save_file_root, class_dict, "val")

    # 创建my_data_label.names文件
    create_class_names(class_dict)


if __name__ == "__main__":
    main()

参考:【YOLO-V3-SPP源码解读】一、数据集制作和格式处理_满船清梦压星河HK的博客-CSDN博客第一步、制作自己的数据集第一步是制作自己的数据集(照片),可以是网络找的,也可以是自己拍的,甚至可以是自己p的。以我下面讲解的数据集为例子,我是在网上找的关于的苹果的病虫害,我简单的做了三个分类,分别是Alternaria_Boltch(斑点落叶病)、Grey_spot(灰斑病)、Rust( 锈病)。我的文件结构如下:每个文件下放着我的数据集照片:就不一一展示了,反正就是有几个类就创几个文件夹,再把各个类别的照片放进对应的文件夹中,这样我们的数据集就初步制作完毕了。第二步、为自己的数据集打标签https://blog.csdn.net/qq_38253797/article/details/117398563

你可能感兴趣的:(Python,目标检测,python)