# 简易的VOC转yolo的代码
# (A simple converter from Pascal VOC XML annotations to YOLO txt label files.)

import xml.dom.minidom as xmldom
import os

def parse_xml(fn, index=0):
    """Read one bounding box from a Pascal VOC annotation file (debug helper).

    Prints the root tag name and the box coordinates as a side effect.

    Args:
        fn: A path or file object accepted by ``xml.dom.minidom.parse``.
        index: Which ``<xmin>``/``<xmax>``/``<ymin>``/``<ymax>`` occurrence
            to read, counted in document order. Defaults to 0 (the first box).

    Returns:
        Tuple ``(xmin, xmax, ymin, ymax)`` of the raw element texts (strings).

    Raises:
        IndexError: if the document has fewer than ``index + 1`` boxes.
    """
    eles = xmldom.parse(fn).documentElement
    print(eles.tagName)
    # Index 0 is the first box, which every annotation with at least one
    # object has; a fixed index of 1 would skip it and fail on 1-object files.
    xmin, xmax, ymin, ymax = (
        eles.getElementsByTagName(tag)[index].firstChild.data
        for tag in ("xmin", "xmax", "ymin", "ymax")
    )
    print(xmin, xmax, ymin, ymax)
    return xmin, xmax, ymin, ymax

# 获取类别数
# 输入是(包含所有xml文件的)文件夹
def AllClassName(path):
    """Collect the distinct class names used across a folder of VOC annotations.

    Args:
        path: Directory containing the ``.xml`` annotation files. A trailing
            path separator is optional. Files not ending in ``.xml``
            (case-insensitive) are ignored.

    Returns:
        List of the texts of every ``<name>`` element, de-duplicated and kept
        in first-seen order. Files are visited in sorted filename order, so
        the resulting order (and any class ids derived from list positions)
        is the same on every run and machine.
    """
    acn = []
    seen = set()
    # os.listdir returns entries in arbitrary order; sort so class order is
    # reproducible.
    for file in sorted(os.listdir(path)):
        # Skip non-annotation files, e.g. .txt labels written by a prior run.
        if not file.lower().endswith('.xml'):
            continue
        eles = xmldom.parse(os.path.join(path, file)).documentElement
        for node in eles.getElementsByTagName("name"):
            name = node.firstChild.data
            if name not in seen:
                seen.add(name)
                acn.append(name)
    return acn

def list2dict(lis):
    """Map each element of ``lis`` to its position in the list.

    If an element appears more than once, its last position wins.
    """
    return {item: position for position, item in enumerate(lis)}
        

def test_parse_xml():
    """Manual smoke check: run parse_xml on one sample training annotation."""
    sample_xml = '/root/teamshare/yolov7/customdata/labels/train/guangjiao_0.xml'
    parse_xml(sample_xml)

    
def convert(size, box):
    """Turn a pixel-space box into YOLO's normalized (x, y, w, h) form.

    Args:
        size: Image size; only ``size[0]`` (width) and ``size[1]`` (height)
            are read.
        box: Corner coordinates ``(xmin, ymin, xmax, ymax)``.

    Returns:
        Tuple ``(x_center, y_center, width, height)``, with horizontal values
        scaled by ``1 / size[0]`` and vertical values by ``1 / size[1]``.
    """
    inv_width, inv_height = 1. / size[0], 1. / size[1]
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    center_x = (left + right) / 2.0 * inv_width
    center_y = (top + bottom) / 2.0 * inv_height
    return (center_x, center_y, (right - left) * inv_width, (bottom - top) * inv_height)

def FormatConversion(filepath, dicts):
    """Convert one VOC XML annotation into a YOLO label file beside it.

    The label file has the same path as ``filepath`` with its final extension
    replaced by ``.txt``. It gets one line per box:
    ``<class_id> <x> <y> <w> <h>``, with coordinates as produced by
    ``convert``. An annotation with no boxes yields an empty file. Nothing is
    written if parsing or validation fails.

    Args:
        filepath: Path to the ``.xml`` annotation.
        dicts: Mapping from class name (text of a ``<name>`` element) to
            integer class id.

    Returns:
        List of ``[class_id, x, y, w, h]`` rows, in the order written.

    Raises:
        Exception: if the ``<depth>`` value is not 3.
        KeyError: if a class name is missing from ``dicts``.
        IndexError: if there are fewer ``<name>``/coordinate elements than
            ``<xmin>`` elements.
    """
    eles = xmldom.parse(filepath).documentElement
    size = [int(eles.getElementsByTagName(tag)[0].firstChild.data)
            for tag in ("width", "height", "depth")]
    if size[2] != 3:
        raise Exception('图像通道维度不对!')

    # Fetch each tag list once instead of on every iteration. Indexing them in
    # parallel assumes the i-th <name> belongs to the i-th box in document
    # order (true for standard VOC layouts).
    names = eles.getElementsByTagName("name")
    coord_lists = [eles.getElementsByTagName(tag)
                   for tag in ("xmin", "ymin", "xmax", "ymax")]
    res = []
    lines = []
    for idx in range(len(coord_lists[0])):
        # float() also accepts coordinates written like "123.0", which int()
        # rejects; integer-valued inputs produce identical converted values.
        bbox = [float(nodes[idx].firstChild.data) for nodes in coord_lists]
        name = dicts[names[idx].firstChild.data]
        bbox = list(convert(size, bbox))
        res.append([name] + bbox)
        lines.append(str(name) + " " + " ".join(str(a) for a in bbox) + '\n')

    # splitext strips only the final extension, so dots elsewhere in the path
    # (e.g. "./labels/a.xml" or "data.v2/a.xml") are preserved. The file is
    # opened only now, so a failure above leaves no empty/partial label file.
    with open(os.path.splitext(filepath)[0] + '.txt', 'w') as txt_file:
        txt_file.writelines(lines)
    return res
        
def main(path, *, remove_xml=True):
    """Convert every VOC annotation in a folder to YOLO labels in place.

    Class names are gathered with ``AllClassName`` and numbered with
    ``list2dict``. Each ``.xml`` file is then converted by
    ``FormatConversion``, which writes a same-named ``.txt`` label beside it.

    Args:
        path: Directory containing the ``.xml`` annotation files (a trailing
            separator is optional). Non-``.xml`` files are ignored.
        remove_xml: If True (the default), delete each ``.xml`` file after it
            has been converted. Pass False to keep the source annotations.

    Returns:
        The ``{class_name: class_id}`` mapping used for the labels.
    """
    acn = AllClassName(path)
    dicts = list2dict(acn)
    # Show the id assignment; the numeric labels are meaningless without it.
    print(dicts)
    for file in sorted(os.listdir(path)):
        if not file.lower().endswith('.xml'):
            continue
        filepath = os.path.join(path, file)
        res = FormatConversion(filepath, dicts)
        print(res)
        if remove_xml:
            os.remove(filepath)
    return dicts
    
    
if __name__ == "__main__":
    import sys

    # Annotation folder: first command-line argument, else the original default.
    main(sys.argv[1] if len(sys.argv) > 1
         else '/root/teamshare/yolov7/customdata/labels/train/')

# 你可能感兴趣的 (you may also be interested in): python, 目标检测 (object detection), 计算机视觉 (computer vision)