Python:YOLOv5 数据集转 VOC 数据集

把 YOLOv5 识别出的结果文本(.txt 标签文件)转换为 VOC 格式(.xml)的数据集。

代码:

import os
import os.path
from PIL import Image
from xml.dom.minidom import Document


def _append_text_element(doc, parent, tag, text):
    """Append ``<tag>text</tag>`` to *parent* and return the new element."""
    element = doc.createElement(tag)
    parent.appendChild(element)
    element.appendChild(doc.createTextNode(text))
    return element


def write_xml(tmp, image, w, h, obj_bud, wx_ml, classes):
    """Write one VOC-style ``<annotation>`` XML file from YOLO label lines.

    Args:
        tmp: Unused. Kept only so existing callers that pass a scratch
            directory prefix keep working; no temporary file is created.
        image: Image file name stored in the ``<filename>`` element.
        w: Image width in pixels.
        h: Image height in pixels.
        obj_bud: Iterable of label lines of the form
            ``class cx cy bw bh [extra ...]``. The four box values are
            multiplied by ``w``/``h``, i.e. they are treated as fractions of
            the image size. Fields may be separated by any whitespace, blank
            lines are skipped and fields after the fifth are ignored.
        wx_ml: Path of the XML file to write.
        classes: Sequence of class names indexed by the numeric class id.

    Raises:
        ValueError: If a non-blank line has fewer than five fields or a
            field is not numeric.
        IndexError: If a class id is outside ``classes``.
    """
    doc = Document()
    annotation = doc.createElement('annotation')
    doc.appendChild(annotation)

    _append_text_element(doc, annotation, 'folder', "VOC2005")
    _append_text_element(doc, annotation, 'filename', image)

    source = doc.createElement('source')
    annotation.appendChild(source)
    _append_text_element(doc, source, 'database', "The VOC2005 Database")
    _append_text_element(doc, source, 'annotation', "PASCAL VOC2005")
    _append_text_element(doc, source, 'image', "flickr")

    size = doc.createElement('size')
    annotation.appendChild(size)
    _append_text_element(doc, size, 'width', str(w))
    _append_text_element(doc, size, 'height', str(h))
    _append_text_element(doc, size, 'depth', "3")

    _append_text_element(doc, annotation, 'segmented', "0")

    for line in obj_bud:
        # split() with no argument tolerates tabs / repeated spaces and
        # yields an empty list for blank lines, which are skipped.
        fields = line.split()
        if not fields:
            continue
        class_id, cx, cy, bw, bh = (float(value) for value in fields[:5])

        obj = doc.createElement('object')
        annotation.appendChild(obj)
        _append_text_element(doc, obj, 'name', classes[int(class_id)])
        _append_text_element(doc, obj, 'pose', "Unspecified")
        _append_text_element(doc, obj, 'truncated', "0")
        _append_text_element(doc, obj, 'difficult', "0")

        # Centre/size fractions -> corner pixel coordinates (truncated).
        bnd_box = doc.createElement('bndbox')
        obj.appendChild(bnd_box)
        _append_text_element(doc, bnd_box, 'xmin', str(int(cx * w - bw * w / 2.0)))
        _append_text_element(doc, bnd_box, 'ymin', str(int(cy * h - bh * h / 2.0)))
        _append_text_element(doc, bnd_box, 'xmax', str(int(cx * w + bw * w / 2.0)))
        _append_text_element(doc, bnd_box, 'ymax', str(int(cy * h + bh * h / 2)))

    # Serializing the root element (not the Document) omits the
    # '<?xml ...?>' declaration, so no temp-file round trip is needed to
    # strip it. Without a declaration the file must be UTF-8.
    with open(wx_ml, "w", encoding="utf-8") as f:
        f.write(annotation.toprettyxml(indent='\t'))


def yolo_to_voc(labels_path, xml_path, img_path, classes):
    """Convert the YOLO ``.txt`` label files in *labels_path* to XML files.

    For each ``<stem>.txt`` directly inside ``labels_path`` the image
    ``<stem>.jpg`` is opened from ``img_path`` to read its pixel size, and
    ``<stem>.xml`` is written to ``xml_path`` via ``write_xml``.

    Args:
        labels_path: Directory containing the label ``.txt`` files. Only
            regular files directly in this directory are processed.
        xml_path: Output directory; created (including parents) if missing.
        img_path: Directory containing the matching ``.jpg`` images.
        classes: Sequence of class names indexed by the numeric class id.

    Label files without a matching image are reported and skipped so that a
    stray file (for example a class list) does not abort the whole batch.
    """
    import tempfile  # local import: only this function needs it

    os.makedirs(xml_path, exist_ok=True)

    # A private scratch directory is removed even if a conversion fails and
    # never collides with an unrelated "temp/" folder in the working dir.
    with tempfile.TemporaryDirectory() as scratch:
        # Trailing separator kept so the value can be used as a plain
        # string prefix in front of a file name.
        scratch_prefix = os.path.join(scratch, "")

        for file in sorted(os.listdir(labels_path)):
            label_file = os.path.join(labels_path, file)
            stem, ext = os.path.splitext(file)
            if ext.lower() != '.txt' or not os.path.isfile(label_file):
                continue

            img_name = stem + '.jpg'
            file_img_path = os.path.join(img_path, img_name)
            if not os.path.isfile(file_img_path):
                print(file + "-->skipped, image not found: " + file_img_path)
                continue

            print(file + "-->start!")
            with Image.open(file_img_path) as im:
                width = int(im.size[0])
                height = int(im.size[1])

            # splitlines() keeps the last object even when the file has no
            # trailing newline; blank lines carry no object and are dropped.
            with open(label_file, "r") as file_label:
                obj = [line for line in file_label.read().splitlines()
                       if line.strip()]

            filename = os.path.join(xml_path, stem + '.xml')
            write_xml(scratch_prefix, img_name, width, height, obj,
                      filename, classes)


if __name__ == '__main__':
    # Class names, listed in the order of the numeric class ids.
    Class_Name = ['M', 'L', 't']

    yolo_to_voc(
        # Folder holding the YOLO annotation .txt files.
        labels_path="Annotations/labels/",
        # Change this to the path where the converted dataset is saved.
        xml_path='Annotations/xml/',
        # Folder holding the corresponding images.
        img_path='Annotations/images/',
        classes=Class_Name,
    )

你可能感兴趣的:(python,开发语言)