将 PASCAL VOC(xml)标注数据转换为 YOLO 的 txt 标注格式

通过 VOC 和 YOLO 两个类完成转换:VOC 类负责解析 xml 标注,YOLO 类负责生成并保存 txt 标注

import os
from xml.etree.ElementTree import dump
import json
import pprint
import sys
import argparse
import xml.etree.ElementTree as Et
from xml.etree.ElementTree import Element, ElementTree
import cv2

class VOC:
    """
    Handler class for the PASCAL VOC XML annotation format.

    Annotations are exchanged as a dict keyed by file stem::

        {stem: {"img_path": str,                      # set by parse()
                "size": {"width", "height", "depth"},
                "objects": {"num_obj": int,
                            "0": {"name": ...,
                                  "bndbox": {"xmin", "ymin",
                                             "xmax", "ymax"}},
                            "1": ...}}}

    ``generate``, ``save`` and ``parse`` never raise: they return
    ``(True, result)`` on success and ``(False, message)`` on failure.
    """

    @staticmethod
    def _format_error(err):
        """Build the failure message returned by the public methods."""
        tb = err.__traceback__
        fname = os.path.split(tb.tb_frame.f_code.co_filename)[1]
        return "ERROR : {}, moreInfo : {}\t{}\t{}".format(err, type(err), fname, tb.tb_lineno)

    @staticmethod
    def _text_element(tag, value):
        """Create ``<tag>value</tag>``; the value is stringified so it can be serialized."""
        node = Element(tag)
        # ElementTree can only serialize str text; ints/floats would make
        # ElementTree.write() fail later, far away from the real cause.
        node.text = None if value is None else str(value)
        return node

    def xml_indent(self, elem, level=0):
        """
        Pretty-print ``elem`` in place using tab indentation.

        :param elem: element whose ``text``/``tail`` whitespace is rewritten
        :param level: nesting depth of ``elem`` (0 for the root)
        """
        pad = "\n" + level * "\t"
        if len(elem):
            if not elem.text or not elem.text.strip():
                elem.text = pad + "\t"
            if not elem.tail or not elem.tail.strip():
                elem.tail = pad
            for child in elem:
                self.xml_indent(child, level + 1)
            # The last child's tail must dedent back to this element's level
            # so the closing tag lines up with the opening one.
            last_child = elem[-1]
            if not last_child.tail or not last_child.tail.strip():
                last_child.tail = pad
        elif level and (not elem.tail or not elem.tail.strip()):
            elem.tail = pad

    def generate(self, data):
        """
        Build one ``<annotation>`` element tree per entry of ``data``.

        :param data: annotation dict in the layout described on the class
        :return: ``(True, {stem: Element})``, or ``(False, message)`` when an
                 entry has fewer than one object or is malformed
        """
        try:
            xml_list = {}

            for key in data:
                element = data[key]

                xml_annotation = Element("annotation")

                xml_size = Element("size")
                for tag in ("width", "height", "depth"):
                    xml_size.append(self._text_element(tag, element["size"][tag]))
                xml_annotation.append(xml_size)

                xml_annotation.append(self._text_element("segmented", "0"))

                objects = element["objects"]
                num_obj = int(objects["num_obj"])
                if num_obj < 1:
                    return False, "number of Object less than 1"

                for i in range(num_obj):
                    obj = objects[str(i)]

                    xml_object = Element("object")
                    xml_object.append(self._text_element("name", obj["name"]))
                    xml_object.append(self._text_element("pose", "Unspecified"))
                    xml_object.append(self._text_element("truncated", "0"))
                    xml_object.append(self._text_element("difficult", "0"))

                    xml_bndbox = Element("bndbox")
                    for tag in ("xmin", "ymin", "xmax", "ymax"):
                        xml_bndbox.append(self._text_element(tag, obj["bndbox"][tag]))
                    xml_object.append(xml_bndbox)

                    xml_annotation.append(xml_object)

                self.xml_indent(xml_annotation)

                # Everything after the first "." of the key is dropped.
                xml_list[key.split(".")[0]] = xml_annotation

            return True, xml_list

        except Exception as e:
            return False, VOC._format_error(e)

    @staticmethod
    def save(xml_list, path):
        """
        Write every element of ``xml_list`` to ``<path>/<key>.xml``.

        :param xml_list: ``{stem: Element}`` as returned by ``generate``
        :param path: output directory (must already exist)
        :return: ``(True, None)`` or ``(False, message)``
        """
        try:
            path = os.path.abspath(path)

            for key, xml in xml_list.items():
                filepath = os.path.join(path, "".join([key, ".xml"]))
                ElementTree(xml).write(filepath)

            return True, None

        except Exception as e:
            return False, VOC._format_error(e)

    @staticmethod
    def parse(path, img_path):
        """
        Read every ``*.xml`` file directly inside ``path``.

        Files without an ``.xml`` extension are ignored. A missing image only
        produces a printed warning; the annotation is still returned.

        :param path: directory holding the VOC XML files (not recursive)
        :param img_path: directory expected to hold ``<stem>.jpg`` images
        :return: ``(True, data)`` in the layout described on the class, or
                 ``(False, message)``; an XML file with no ``<object>``
                 aborts the whole parse with ``"number object zero"``
        """
        try:
            dir_path, _, filenames = next(os.walk(os.path.abspath(path)))

            data = {}

            for filename in filenames:
                stem, ext = os.path.splitext(filename)
                if ext.lower() != ".xml":
                    continue

                # Passing the path lets ElementTree open *and close* the file.
                root = Et.parse(os.path.join(dir_path, filename)).getroot()

                xml_size = root.find("size")
                size = {
                    "width": xml_size.find("width").text,
                    "height": xml_size.find("height").text,
                    "depth": xml_size.find("depth").text,
                }

                objects = root.findall("object")
                if len(objects) == 0:
                    return False, "number object zero"

                obj = {"num_obj": len(objects)}
                for obj_index, _object in enumerate(objects):
                    xml_bndbox = _object.find("bndbox")
                    obj[str(obj_index)] = {
                        "name": _object.find("name").text,
                        "bndbox": {
                            "xmin": float(xml_bndbox.find("xmin").text),
                            "ymin": float(xml_bndbox.find("ymin").text),
                            "xmax": float(xml_bndbox.find("xmax").text),
                            "ymax": float(xml_bndbox.find("ymax").text),
                        },
                    }

                image_file = os.path.join(img_path, stem + ".jpg")
                if not os.path.exists(image_file):
                    print('img not exists : {}'.format(image_file))

                data[stem] = {
                    "img_path": image_file,
                    "size": size,
                    "objects": obj,
                }

            return True, data

        except Exception as e:
            return False, VOC._format_error(e)
class YOLO:
    """
    Handler class for the YOLO txt annotation format.

    Each label file holds one object per line as
    ``<class_id> <x_center> <y_center> <width> <height>`` with the four box
    values normalized by the image width/height.

    ``parse``, ``generate`` and ``save`` never raise: they return
    ``(True, result)`` on success and ``(False, message)`` on failure.
    """

    # (substring, canonical class) pairs tried in order for labels that are
    # not in the class list.
    _NAME_ALIASES = (("limit", "limit"), ("van", "car"), ("rider", "person"))

    def __init__(self, cls_list_path):
        """
        :param cls_list_path: text file with one class name per line; the
                              line index is the YOLO class id
        """
        with open(cls_list_path, 'r') as file:
            self.cls_list = file.read().splitlines()

    @staticmethod
    def _format_error(err):
        """Build the failure message returned by the public methods."""
        tb = err.__traceback__
        fname = os.path.split(tb.tb_frame.f_code.co_filename)[1]
        return "ERROR : {}, moreInfo : {}\t{}\t{}".format(err, type(err), fname, tb.tb_lineno)

    @staticmethod
    def _preview_box(img_file, xmin, ymin, xmax, ymax):
        """Best-effort 2 s on-screen preview of a box; never aborts the caller."""
        if not img_file:
            return
        try:
            img = cv2.imread(img_file)
            if img is None:
                return
            cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 0, 255), 3)
            cv2.imshow('test', img)
            cv2.waitKey(2000)
        except cv2.error as err:
            # e.g. no display available; the preview is only a debugging aid
            print('preview skipped : {}'.format(err))

    def coordinateCvt2YOLO(self, size, box):
        """
        Convert a pixel box to normalized YOLO coordinates.

        :param size: ``(image_width, image_height)``
        :param box: ``(xmin, xmax, ymin, ymax)`` -- note the x values come first
        :return: ``(x_center, y_center, width, height)`` divided by the image
                 size and rounded to 10 decimals
        """
        dw = 1. / size[0]
        dh = 1. / size[1]

        x_center = (box[0] + box[1]) / 2.0
        y_center = (box[2] + box[3]) / 2.0
        width = box[1] - box[0]
        height = box[3] - box[2]

        return (round(x_center * dw, 10), round(y_center * dh, 10),
                round(width * dw, 10), round(height * dh, 10))

    def parse(self, label_path, img_path, img_type=".png"):
        """
        Read every ``*.txt`` label file directly inside ``label_path``.

        The image ``<stem><img_type>`` is loaded with ``cv2.imread`` to obtain
        the pixel size needed to de-normalize the boxes.

        :param label_path: directory with YOLO label files (not recursive)
        :param img_path: directory with the matching images
        :param img_type: image file extension including the dot
        :return: ``(True, data)`` where ``data[stem]`` has ``"size"`` and
                 ``"objects"``; each object's ``"name"`` is the class id
                 string found in the label file. ``(False, message)`` on error.
        """
        try:
            dir_path, _, filenames = next(os.walk(os.path.abspath(label_path)))

            data = {}

            for filename in filenames:
                # splitext keeps dots inside the stem ("a.b.txt" -> "a.b")
                stem, ext = os.path.splitext(filename)
                if ext.lower() != ".txt":
                    continue

                image_file = os.path.join(img_path, "".join([stem, img_type]))
                img = cv2.imread(image_file)
                if img is None:
                    raise IOError("cannot read image: {}".format(image_file))

                # cv2 images are indexed (rows, cols[, channels])
                img_width = str(img.shape[1])
                img_height = str(img.shape[0])
                img_depth = str(img.shape[2]) if img.ndim > 2 else "1"

                size = {
                    "width": img_width,
                    "height": img_height,
                    "depth": img_depth,
                }

                obj = {}
                obj_cnt = 0

                with open(os.path.join(dir_path, filename), "r") as txt:
                    for line in txt:
                        elements = line.split()
                        if not elements:
                            continue  # tolerate blank lines

                        w = float(elements[3]) * float(img_width)
                        h = float(elements[4]) * float(img_height)
                        xmin = float(elements[1]) * float(img_width) - w / 2
                        ymin = float(elements[2]) * float(img_height) - h / 2

                        obj[str(obj_cnt)] = {
                            "name": elements[0],
                            "bndbox": {
                                "xmin": float(xmin),
                                "ymin": float(ymin),
                                "xmax": float(xmin + w),
                                "ymax": float(ymin + h),
                            },
                        }
                        obj_cnt += 1

                obj["num_obj"] = obj_cnt

                data[stem] = {
                    "size": size,
                    "objects": obj,
                }

            return True, data

        except Exception as e:
            return False, YOLO._format_error(e)

    def generate(self, data):
        """
        Render the label-file text for every entry of ``data``.

        Labels missing from the class list are mapped through
        ``_NAME_ALIASES``; anything still unknown is skipped (with a printed
        notice, except for traffic-light labels which are skipped silently).
        ``data`` itself is not modified.

        :param data: ``{stem: {"size": ..., "objects": ...}}`` with
                     ``objects["num_obj"]`` and objects keyed ``"0"``, ``"1"``, ...
        :return: ``(True, {stem: text})`` or ``(False, message)``
        """
        try:
            result = {}

            for key in data:
                entry = data[key]
                img_width = int(entry["size"]["width"])
                img_height = int(entry["size"]["height"])

                lines = []

                for idx in range(int(entry["objects"]["num_obj"])):
                    obj = entry["objects"][str(idx)]
                    box = obj["bndbox"]
                    xmin, ymin = box["xmin"], box["ymin"]
                    xmax, ymax = box["xmax"], box["ymax"]

                    b = (float(xmin), float(xmax), float(ymin), float(ymax))
                    bb = self.coordinateCvt2YOLO((img_width, img_height), b)

                    name = obj["name"]
                    if name not in self.cls_list:
                        alias = next(
                            (target for fragment, target in self._NAME_ALIASES if fragment in name),
                            None,
                        )
                        if alias is not None:
                            name = alias
                        elif 'wan' in name:
                            print(key, '-----------------------------------------------------')
                            self._preview_box(entry.get('img_path'), xmin, ymin, xmax, ymax)
                            continue
                        else:
                            if 'traffic light' not in name and 'trafiic light' not in name:
                                print(name, 'not in cls list!------------')
                            continue

                        if name not in self.cls_list:
                            # alias target itself is not a known class
                            print(name, 'not in cls list!------------')
                            continue

                    cls_id = self.cls_list.index(name)
                    lines.append(" ".join([str(cls_id)] + [str(e) for e in bb]) + "\n")

                result[key] = "".join(lines)

            return True, result

        except Exception as e:
            return False, YOLO._format_error(e)

    def save(self, data, save_path, img_path, img_type, manipast_path):
        """
        Write one ``<key>.txt`` label file per entry plus a manifest.

        :param data: ``{stem: text}`` as returned by ``generate``
        :param save_path: directory receiving the label files
        :param img_path: image directory used to build the manifest entries
        :param img_type: image file extension including the dot
        :param manipast_path: directory receiving ``manifast.txt``, which
                              lists one absolute image path per line
        :return: ``(True, None)`` or ``(False, message)``
        """
        try:
            manifest = os.path.abspath(os.path.join(manipast_path, "manifast.txt"))
            with open(manifest, "w") as manipast_file:

                for key in data:
                    image_file = os.path.abspath(os.path.join(img_path, "".join([key, img_type])))
                    # newline appended after abspath so it is never treated as part of the path
                    manipast_file.write(image_file + "\n")

                    label_file = os.path.abspath(os.path.join(save_path, "".join([key, ".txt"])))
                    with open(label_file, "w") as output_txt_file:
                        output_txt_file.write(data[key])

            return True, None

        except Exception as e:
            return False, YOLO._format_error(e)

将 YOLO 标注框绘制到对应图片上,检查标注效果

def view_yolo_txt(txt_dir, img_dir, class_names=None):
    """
    Draw YOLO labels onto their images and show them one at a time.

    For every ``*.txt`` file in ``txt_dir`` the image ``<stem>.jpg`` is read
    from ``img_dir``; each label line ``cls cx cy w h`` (normalized) is drawn
    as a rectangle with its class label. Images with at least one box are
    shown at half size and block until a key is pressed.

    :param txt_dir: directory with YOLO label files
    :param img_dir: directory with the matching ``.jpg`` images
    :param class_names: optional sequence/mapping from integer class id to
                        label text; when omitted, a module-level
                        ``class_idx_list`` is used if one exists, otherwise
                        the numeric id is drawn
    """
    if class_names is None:
        class_names = globals().get("class_idx_list")

    for txtf in sorted(os.listdir(txt_dir)):
        stem, ext = os.path.splitext(txtf)
        if ext.lower() != '.txt':
            continue

        with open(os.path.join(txt_dir, txtf), 'r') as f:
            rows = [[float(j) for j in x.split()] for x in f.read().splitlines()]

        img_file = os.path.join(img_dir, stem + '.jpg')
        img = cv2.imread(img_file)
        if img is None:
            # imread signals failure by returning None instead of raising
            print('img not readable : {}'.format(img_file))
            continue
        h, w = img.shape[:2]

        need_show = 0
        for row in rows:
            if len(row) < 5:
                continue  # blank or truncated line
            cls, cx, cy, cw, ch = row[:5]
            need_show = 1
            xmin = int(w * (cx - 0.5 * cw))
            ymin = int(h * (cy - 0.5 * ch))
            xmax = int(w * (cx + 0.5 * cw))
            ymax = int(h * (cy + 0.5 * ch))
            # the id was parsed as float; it must be an int to be used as an index
            cls_id = int(cls)
            label = str(class_names[cls_id]) if class_names is not None else str(cls_id)
            cv2.putText(img, label, (xmin, ymin - 2), cv2.FONT_ITALIC, 1, (255, 255, 0), 1)
            cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)

        if need_show:
            img = cv2.resize(img, None, fx=0.5, fy=0.5)
            print(img_file)
            cv2.imshow('tt', img)
            cv2.waitKey(0)

yolo_to_coco

你可能感兴趣的:(深度学习,深度学习)