逆向转换：将物体检测数据集反推为 labelme 标注数据

对一些现有的数据集进行反推，生成 labelme 标注格式的数据。生成的效果如下图所示：

这里使用了 RSOD 数据集的部分数据，将其 VOC 格式的标注反推为 labelme 的标注数据。

代码如下（运行前需在当前目录准备一个由 labelme 生成的 1.json 作为模板，脚本会从中读取 flags、lineColor 和 fillColor）：

import sys
import os.path as osp
import io
from labelme.logger import logger
from labelme import PY2
from labelme import QT4
import PIL.Image
import base64
from labelme import utils
import os
import cv2
import xml.etree.ElementTree as ET

# Make the parent directory importable. '..' is resolved against the current
# working directory, not against this file's location.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)
import json
from PIL import Image

# Disable PIL's decompression-bomb pixel limit so that very large images can be opened.
Image.MAX_IMAGE_PIXELS = None
# Root directory of the dataset; must end with '/' because paths are built by concatenation.
imageroot = 'RSOD/'
# Object categories; each one is a sub-directory of imageroot holding
# 'Annotation/xml' and 'JPEGImages'.
listDir = ['aircraft', 'oiltank']

def load_image_file(filename):
    """Load an image, apply its EXIF orientation and return the re-encoded bytes.

    ``.jpg``/``.jpeg`` files are re-encoded as JPEG and everything else as PNG.
    Under Python 2 with Qt4, PNG is always used.

    :param filename: path of the image to read
    :return: encoded image bytes, or ``None`` if the file cannot be opened
             (the failure is logged)
    """
    try:
        image_pil = PIL.Image.open(filename)
    except IOError:
        logger.error('Failed opening image file: {}'.format(filename))
        return None

    # Close the underlying file handle deterministically instead of leaving it
    # open until the image object is garbage-collected.
    with image_pil:
        # apply orientation to image according to exif
        oriented = utils.apply_exif_orientation(image_pil)

        ext = osp.splitext(filename)[1].lower()
        if PY2 and QT4:
            image_format = 'PNG'
        elif ext in ('.jpg', '.jpeg'):
            image_format = 'JPEG'
        else:
            image_format = 'PNG'

        with io.BytesIO() as buffer:
            oriented.save(buffer, format=image_format)
            return buffer.getvalue()


def dict_json(flags, imageData, shapes, imagePath, fillColor=None, lineColor=None, imageHeight=100, imageWidth=100):
    '''
    Build a labelme annotation dictionary (format version "3.16.4").

    :param flags: image-level flags dict
    :param imageData: base64-encoded image content (str)
    :param shapes: list of shape dicts
    :param imagePath: image path; only the part after the last '/' is stored
    :param fillColor: default fill colour (list) or None
    :param lineColor: default line colour (list) or None
    :param imageHeight: image height stored in the file
    :param imageWidth: image width stored in the file
    :return: dict whose key order matches the serialised JSON layout
    '''
    file_name = imagePath.rsplit('/', 1)[-1]
    ordered_fields = (
        ("version", "3.16.4"),
        ("flags", flags),
        ("shapes", shapes),
        ("lineColor", lineColor),
        ("fillColor", fillColor),
        ("imagePath", file_name),
        ("imageData", imageData),
        ("imageHeight", imageHeight),
        ("imageWidth", imageWidth),
    )
    return dict(ordered_fields)


# Template labelme JSON; only its "flags", "lineColor" and "fillColor" entries
# are reused for every generated file.
with open('1.json') as template_file:
    data = json.load(template_file)


def _parse_voc_shapes(tree, flags):
    """Convert every VOC ``<object>`` element in *tree* into a labelme rectangle shape.

    :param tree: parsed VOC annotation (``xml.etree.ElementTree.ElementTree``)
    :param flags: flags dict attached to every generated shape
    :return: list of labelme shape dicts, in document order
    """
    newshapes = []
    for elem in tree.iter():
        # Tags are matched by substring, so namespaced tags like '{ns}object' match too.
        if 'object' not in elem.tag:
            continue
        name = ''
        xminNode = yminNode = xmaxNode = ymaxNode = 0
        for attr in list(elem):
            if 'name' in attr.tag:
                name = attr.text
            if 'bndbox' in attr.tag:
                for dim in list(attr):
                    # Coordinates may be written as floats ("12.0"); round to whole pixels.
                    if 'xmin' in dim.tag:
                        xminNode = int(round(float(dim.text)))
                    if 'ymin' in dim.tag:
                        yminNode = int(round(float(dim.text)))
                    if 'xmax' in dim.tag:
                        xmaxNode = int(round(float(dim.text)))
                    if 'ymax' in dim.tag:
                        ymaxNode = int(round(float(dim.text)))
        # A labelme rectangle is stored as two opposite corner points.
        newPoints = [[float(xminNode), float(yminNode)], [float(xmaxNode), float(ymaxNode)]]
        newshapes.append(
            {"label": name, "line_color": None, "fill_color": None, "points": newPoints,
             "shape_type": 'rectangle', "flags": flags})
    return newshapes


for subPath in listDir:
    xmlpathName = imageroot + subPath + '/Annotation/xml'
    imagepath = imageroot + subPath + '/JPEGImages'
    resultFile = os.listdir(xmlpathName)
    for file in resultFile:
        stem, ext = osp.splitext(file)
        # Skip anything that is not an XML annotation (e.g. .DS_Store, backups).
        if ext.lower() != '.xml':
            continue
        print(file)
        # splitext keeps inner dots ("a.b.xml" -> "a.b"), unlike split('.')[0].
        imagePH = imagepath + '/' + stem + '.jpg'
        print(imagePH)
        tree = ET.parse(xmlpathName + '/' + file)
        # cv2.imread returns None (rather than raising) for missing/unreadable files.
        image = cv2.imread(imagePH)
        if image is None:
            print('Skipping {}: image not found or unreadable'.format(imagePH))
            continue
        flags = data["flags"]
        lineColor = data["lineColor"]
        fillColor = data['fillColor']
        newshapes = _parse_voc_shapes(tree, flags)
        imageBytes = load_image_file(imagePH)
        if imageBytes is None:
            # load_image_file returns None when the image cannot be opened.
            continue
        imageData_90 = base64.b64encode(imageBytes).decode('utf-8')
        imageHeight, imageWidth = image.shape[:2]
        data_90 = dict_json(flags, imageData_90, newshapes, imagePH, fillColor, lineColor, imageHeight, imageWidth)
        # Write the JSON next to the image, with the same stem.
        json_file = osp.splitext(imagePH)[0] + '.json'
        with open(json_file, 'w') as json_out:
            json.dump(data_90, json_out)

 

你可能感兴趣的:(人工智能)