VOC-Format Dataset Processing Utilities (Python Scripts)

Contents

  • 1. Source Code
    • 1.1 Interface Description
    • 1.2 Code
  • 2. Reference

1. Source Code

1.1 Interface Description

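parse_vocxml: Parses a VOC XML file and returns every annotated box as a list bboxes = [bbox_1, bbox_2, …], where each bounding box is bbox = [cls, x_min, y_min, x_max, y_max] and cls (short for class) is the class name.
del_specific_cls: Deletes the annotation boxes of the specified classes from a VOC XML file. The parameter clss (short for classes) is a set.
change_cls_name: Renames classes in a VOC XML file, replacing each old name with the given new name. The parameter cls_old2new_dict is a dictionary that maps old names to new names.
merge_xmls_for_same_image: Merges several VOC XML files for the same image into a single VOC XML file. The arguments are args = [xml_save_path, xml_1_path, xml_2_path, …], where xml_save_path is the output path and the remaining entries are the XML files to merge. The first input file supplies the base annotation, and the objects from the later files are appended to it. Use case: one person labels class A, another labels class B, and the results need to be combined at the end.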
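
To make the bbox layout described above concrete, here is a minimal sketch that parses a small annotation held in a string. The sample XML and the helper name parse_voc_string are made up for illustration; the article's own parse_vocxml (section 1.2) reads from a file path and walks the tree differently.

```python
import xml.etree.ElementTree as ET

# Hypothetical annotation, invented for this example.
SAMPLE_XML = """<annotation>
    <filename>demo.jpg</filename>
    <object>
        <name>cat</name>
        <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>110</xmax><ymax>220</ymax></bndbox>
    </object>
    <object>
        <name>dog</name>
        <bndbox><xmin>5</xmin><ymin>6</ymin><xmax>50</xmax><ymax>60</ymax></bndbox>
    </object>
</annotation>"""


def parse_voc_string(xml_text):
    """Return [[cls, x_min, y_min, x_max, y_max], ...] for every <object>."""
    root = ET.fromstring(xml_text)
    bboxes = []
    for obj in root.iter('object'):
        cls = obj.findtext('name')
        box = obj.find('bndbox')
        # Read the four corner coordinates in the documented order.
        coords = [int(box.findtext(tag)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        bboxes.append([cls] + coords)
    return bboxes


bboxes = parse_voc_string(SAMPLE_XML)
print(bboxes)
```

Each inner list starts with the class name, followed by the four integer coordinates.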

Each function below processes a single VOC XML file. To process an entire dataset, consider an approach like the following:

# An example
xml_paths = [os.path.join(dir_path, var) for var in os.listdir(dir_path) if var.endswith('.xml')]
for xml_path in xml_paths:
    del_specific_cls(xml_path, {'cls_1', 'cls_2'})
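
The same loop can be wrapped in a small reusable helper. The sketch below is only one possible arrangement: the names list_xml_paths and apply_to_dataset are hypothetical, and the demo runs against a temporary directory with a recording function standing in for a real per-file function such as del_specific_cls.

```python
import os
import tempfile


def list_xml_paths(dir_path):
    """Return sorted full paths of the .xml files directly inside dir_path."""
    return sorted(
        os.path.join(dir_path, name)
        for name in os.listdir(dir_path)
        if name.endswith('.xml')
    )


def apply_to_dataset(dir_path, func, *func_args):
    """Call func(xml_path, *func_args) for each file; return the paths visited."""
    xml_paths = list_xml_paths(dir_path)
    for xml_path in xml_paths:
        func(xml_path, *func_args)
    return xml_paths


seen = []


def record(xml_path, tag):
    """Stand-in for a per-file function; it only records what it was given."""
    seen.append((os.path.basename(xml_path), tag))


# Demo on a throwaway directory with two XML files and one unrelated file.
with tempfile.TemporaryDirectory() as tmp_dir:
    for name in ('a.xml', 'b.xml', 'notes.txt'):
        with open(os.path.join(tmp_dir, name), 'w') as f:
            f.write('<annotation/>')
    visited = apply_to_dataset(tmp_dir, record, 'demo')
```

With the article's functions, a call would look like apply_to_dataset(dir_path, del_specific_cls, {'cls_1', 'cls_2'}). Using os.path.join avoids depending on dir_path ending with a path separator.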

1.2 Code

import os
import xml.etree.ElementTree as ET


def parse_vocxml(xml_path):
    '''
    return bboxes = [bbox_1, bbox_2, ...]
    where bbox = [cls, x_min, y_min, x_max, y_max].
    '''
    if not os.path.exists(xml_path):
        raise FileNotFoundError(xml_path)

    tree = ET.parse(xml_path)
    bboxes = []

    for var in tree.iter():
        if var.tag == 'object':
            cls, x_min, y_min, x_max, y_max = None, None, None, None, None
            for element in list(var):
                if element.tag == 'name':
                    cls = element.text
                elif element.tag == 'bndbox':
                    for coordinate in list(element):
                        if coordinate.tag == 'xmin':
                            x_min = int(coordinate.text)
                        elif coordinate.tag == 'ymin':
                            y_min = int(coordinate.text)
                        elif coordinate.tag == 'xmax':
                            x_max = int(coordinate.text)
                        elif coordinate.tag == 'ymax':
                            y_max = int(coordinate.text)
            bbox = [cls, x_min, y_min, x_max, y_max]
            bboxes.append(bbox)

    return bboxes


def del_specific_cls(xml_path, clss):
    '''
    delete specific clss = set([cls_1, cls_2, ...]) from a voc-xml file.
    '''
    if not os.path.exists(xml_path):
        raise FileNotFoundError(xml_path)

    tree = ET.parse(xml_path)
    root = tree.getroot()

    # findall() returns a list, so removing children while looping is safe.
    # Only direct <object> children of the root are considered, because
    # root.remove() raises ValueError for elements that are not its children.
    for obj in root.findall('object'):
        if obj.findtext('name') in clss:
            root.remove(obj)

    tree.write(xml_path, encoding="utf-8", xml_declaration=True)


def change_cls_name(xml_path, cls_old2new_dict):
    '''
    change cls name from cls_old to cls_new for a voc-xml file.
    cls_old2new_dict = {cls_old: cls_new} is a dictionary.
    '''
    if not os.path.exists(xml_path):
        raise FileNotFoundError(xml_path)

    tree = ET.parse(xml_path)
    root = tree.getroot()

    # Rename the class of every <object> whose current name is a key in the dict.
    for obj in root.iter('object'):
        name = obj.find('name')
        if name is not None and name.text in cls_old2new_dict:
            name.text = cls_old2new_dict[name.text]

    tree.write(xml_path, encoding="utf-8", xml_declaration=True)


def merge_xmls_for_same_image(*args):
    '''
    args = (target_xml_save_path, xml_1, xml_2, ...)
    '''

    def _append_obj(root, bbox):
        '''
        bbox = [cls, x_min, y_min, x_max, y_max]
        '''
        obj = ET.Element('object')
        name = ET.SubElement(obj, 'name')
        name.text = bbox[0]
        pose = ET.SubElement(obj, 'pose')
        pose.text = 'Unspecified'
        truncated = ET.SubElement(obj, 'truncated')
        truncated.text = '0'
        difficult = ET.SubElement(obj, 'difficult')
        difficult.text = '0'
        bndbox = ET.SubElement(obj, 'bndbox')
        xmin = ET.SubElement(bndbox, 'xmin')
        xmin.text = str(bbox[1])
        ymin = ET.SubElement(bndbox, 'ymin')
        ymin.text = str(bbox[2])
        xmax = ET.SubElement(bndbox, 'xmax')
        xmax.text = str(bbox[3])
        ymax = ET.SubElement(bndbox, 'ymax')
        ymax.text = str(bbox[4])
        root.append(obj)
        return root

    # *args is always a tuple, so only its length needs checking.
    if len(args) < 2:
        raise ValueError('expected at least (target_xml_save_path, xml_1).')

    target_xml_save_path = args[0]

    tree = ET.parse(args[1])
    root = tree.getroot()

    for arg in args[2:]:
        bboxes = parse_vocxml(arg)
        for bbox in bboxes:
            _append_obj(root, bbox)

    tree.write(target_xml_save_path, encoding='utf-8', xml_declaration=True)


if __name__ == '__main__':
    dir_path = ''  # set this to the directory containing the VOC XML files
    xml_paths = [os.path.join(dir_path, var) for var in os.listdir(dir_path) if var.endswith('.xml')]
    for xml_path in xml_paths:
        del_specific_cls(xml_path, {'cls_1', 'cls_2'})
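
One thing to be aware of in merge_xmls_for_same_image: the _append_obj helper rebuilds each object from its bbox, so pose, truncated, and difficult are reset to default values for objects taken from the second and later files. If those fields matter, one alternative (not the article's method) is to copy the object elements across unchanged. The sketch below assumes this approach; the function name merge_by_copying_objects, the helper _write_annotation, and the sample data are all hypothetical.

```python
import copy
import os
import tempfile
import xml.etree.ElementTree as ET


def merge_by_copying_objects(save_path, base_xml, *other_xmls):
    """Write base_xml plus deep copies of every <object> found in other_xmls."""
    tree = ET.parse(base_xml)
    root = tree.getroot()
    for xml_path in other_xmls:
        for obj in ET.parse(xml_path).getroot().iter('object'):
            # Copy the whole element so all of its child fields are preserved.
            root.append(copy.deepcopy(obj))
    tree.write(save_path, encoding='utf-8', xml_declaration=True)


def _write_annotation(path, cls, difficult):
    """Create a tiny one-object annotation file (made-up data)."""
    with open(path, 'w', encoding='utf-8') as f:
        f.write(
            '<annotation><filename>demo.jpg</filename><object>'
            f'<name>{cls}</name><difficult>{difficult}</difficult>'
            '<bndbox><xmin>1</xmin><ymin>2</ymin><xmax>3</xmax><ymax>4</ymax></bndbox>'
            '</object></annotation>'
        )


# Two annotators label different classes on the same image, then merge.
with tempfile.TemporaryDirectory() as tmp_dir:
    xml_a = os.path.join(tmp_dir, 'annotator_a.xml')
    xml_b = os.path.join(tmp_dir, 'annotator_b.xml')
    merged = os.path.join(tmp_dir, 'merged.xml')
    _write_annotation(xml_a, 'cls_A', 0)
    _write_annotation(xml_b, 'cls_B', 1)
    merge_by_copying_objects(merged, xml_a, xml_b)
    merged_objects = [
        (obj.findtext('name'), obj.findtext('difficult'))
        for obj in ET.parse(merged).getroot().iter('object')
    ]
```

Here the difficult flag of the second annotator's object survives the merge, whereas rebuilding the object from its bbox would write it as 0.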

2. Reference

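① Building a VOC-format dataset operation class, part 3: deleting labels of specified classes and renaming specified label classes
② GitHub project (includes a usage guide): https://github.com/A-mockingbird/VOCtype-datasetOperation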
