深度学习Labelme的json文件转roLabelImg的xml文件

文章目录

  • 前言
  • 一、数据集准备
  • 二、标注工具下载
    • 1.源码下载
    • 2.所需依赖安装
    • 3、启动命令
    • 4、常用快捷键
  • 三、文件转换:
  • 总结


前言

最近,想将之前训练的YOLOV5模型更改为旋转目标检测(YOLOV5_obb)的,但是又不想重新标注数据集,在网上也没找到适合的脚本。所以就写了个脚本将之前标注的数据集的json文件转换为了xml文件,然后在通过对应的旋转标注软件进行调整即可。


一、数据集准备

该转换脚本的前提是你之前已经有了可供YOLOV5训练的数据集,且标注文件为json文件,标注框选择的是矩形框。这里仅实现矩形框转换为xml,如是多边形,则需要自己对脚本进行更改。数据集如下:
深度学习Labelme的json文件转roLabelImg的xml文件_第1张图片

二、标注工具下载

这里我们用到的工具是roLabelImg,该工具是在LabelImg的基础上加入了旋转角度的功能,标注数据集的格式为VOC,参考链接为:roLabelImg安装。

1.源码下载

下载链接为:roLabelImg

2.所需依赖安装

命令如下,仅供参考:

pip install pyqt5-tools
pip install lxml

3、启动命令

命令如下:

pyrcc5 -o resources.py resources.qrc 
python roLabelImg.py(第二次直接执行这个命令应该就可以了)

4、常用快捷键

w: 创建水平矩形目标框
e: 创建旋转矩形目标框
zxcv: 旋转目标框,键z和建x是逆时针旋转,键c和键v是顺时针旋转

三、文件转换:

这里的大致思路就是先解析出原来的json文件,然后删除其中的冗余信息,如图片信息等,然后将剩余信息写入对应的xml文件中。一下代码将路径改改应该就可以运行,代码如下:

import os,glob,json
import xml.etree.ElementTree as ET
from pathlib import Path
import numpy as np

#获得中心点的坐标
def get_cx_cy(points):
    cx=(points[0][0]+points[1][0])/2.
    cy=(points[0][1]+points[1][1])/2.
    points=np.array(points)
    w=max(points[:,0])-min(points[:,0])
    h=max(points[:,1])-min(points[:,1])
    return cx,cy,w,h

#美化内容(+换行)
def indent(elem,level=0):
    i = "\n" + level * "  "
    if len(elem):
        if not elem.text or not elem.text.strip():
            elem.text = i + "  "
        if not elem.tail or not elem.tail.strip():
            elem.tail = i
        for elem in elem:
            indent(elem, level + 1)
        if not elem.tail or not elem.tail.strip():
            elem.tail = i
    else:
        if level and (not elem.tail or not elem.tail.strip()):
            elem.tail = i



#写入xml文件
def write_xml(data,imgpath,savepath):
    root=ET.Element('annotation') #创建节点
    tree=ET.ElementTree(root) #创建文档

    #图片文件上一级目录
    folder=ET.Element("folder")
    img_folder=imgpath.split(os.sep)[-2]
    folder.text=img_folder
    root.append(folder)

    #文件名
    imgname=Path(imgpath).stem
    filename=ET.Element("filename")
    filename.text=imgname
    root.append(filename)

    #路径
    path=ET.Element("path")
    path.text=imgpath
    root.append(path)

    #source
    source=ET.Element("source")
    root.append(source)
    database=ET.Element("database")
    database.text="Unknown"
    source.append(database)

    #size
    size=ET.Element("size")
    root.append(size)
    width=ET.Element("width") #宽
    width.text=str(data["imageWidth"])
    size.append(width)
    height=ET.Element("height")#高
    height.text=str(data["imageHeight"])
    size.append(height)
    depth=ET.Element("depth") #深度
    depth.text=str(3)
    size.append(depth)

    #segmented
    segmented=ET.Element("segmented")
    segmented.text=str(0)
    root.append(segmented)

    # 目标
    for shape in data["shapes"]:
        object_=ET.Element("object")
        root.append(object_)
        #标注框类型
        type_=ET.Element("type")
        type_.text="robndbox"
        object_.append(type_)
        #目标类别
        name=ET.Element("name")
        name.text=shape["label"]
        object_.append(name)
        #pose
        pose=ET.Element("pose")
        pose.text="Unspecified"
        object_.append(pose)
        #截断情况
        truncated=ET.Element("truncated")
        truncated.text=str(0) #默认为0,表示未截断
        object_.append(truncated)
        #样本困难度
        difficult=ET.Element("difficult")
        difficult.text=str(0) #默认为0,表示非困难样本
        object_.append(difficult)
        #四个端点
        robndbox=ET.Element("robndbox")
        object_.append(robndbox)
        cx,cy,w,h=get_cx_cy(shape["points"])
        #cx
        cx_=ET.Element("cx")
        cx_.text=str(cx)
        robndbox.append(cx_)
        #cy
        cy_=ET.Element("cy")
        cy_.text=str(cy)
        robndbox.append(cy_)
        #w
        w_=ET.Element("w")
        w_.text=str(w)
        robndbox.append(w_)
        #h
        h_=ET.Element("h")
        h_.text=str(h)
        robndbox.append(h_)
        #angle
        angle=ET.Element("angle")
        angle.text=str(0.0)
        robndbox.append(angle)

    indent(root,0)
    tree.write(savepath+os.sep+imgname+".xml","UTF-8",xml_declaration=True)

#解析json文件
def load_json(jsonpath):
    data=json.load(open(jsonpath,"r"))
    del data["version"]
    try:
        del data["flags"]
    except Exception as e:
        del data["flag"]
    del data["imagePath"]
    del data["imageData"]
    return data



if __name__ == '__main__':
    img_dir=r"" #要转换的原json文件的路径
    save_dir=r"" #保存文件夹路径
    os.makedirs(save_dir,exist_ok=True)
    imglist=glob.glob(img_dir+"*.jpg")
    for imgpath in imglist:
        # print(imgpath)
        jsonpath=imgpath.replace(".jpg",".json")
        data=load_json(jsonpath)

        write_xml(data,imgpath,save_dir)

如若json文件时多边形标注的,那么需要在get_cx_cy函数中进行更改


总结

以上就是本篇的全部内容,如果问题,欢迎评论区留言,或加入QQ群:995760755交流。

你可能感兴趣的:(深度学习,json,xml,计算机视觉,图像处理)