# 通过两个类来转换 — conversion between annotation formats is done through two handler classes, VOC and YOLO.
import os
from xml.etree.ElementTree import dump
import json
import pprint
import sys
import argparse
import xml.etree.ElementTree as Et
from xml.etree.ElementTree import Element, ElementTree
import cv2
class VOC:
    """
    Handler class for the PASCAL VOC XML annotation format.

    ``parse`` and ``generate`` share an intermediate dict keyed by image name
    (file name without extension)::

        {"size": {"width": ..., "height": ..., "depth": ...},
         "objects": {"num_obj": N,
                     "0": {"name": ..., "bndbox": {"xmin": ..., "ymin": ...,
                                                   "xmax": ..., "ymax": ...}},
                     ...},
         "img_path": ...}   # only filled in by ``parse``
    """

    # Child tags of <bndbox>, in the order they are written and read.
    _BNDBOX_KEYS = ("xmin", "ymin", "xmax", "ymax")

    @staticmethod
    def _error_message(e):
        """Format ``e`` as "ERROR : <e>, moreInfo : <type>\\t<file>\\t<line>"."""
        exc_tb = e.__traceback__
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        return "ERROR : {}, moreInfo : {}\t{}\t{}".format(e, type(e), fname, exc_tb.tb_lineno)

    def xml_indent(self, elem, level=0):
        """
        Pretty-print ``elem`` in place by inserting newline + tab whitespace.

        Every nesting level is indented by one tab. After recursing, the last
        child's tail is reset to this level so the parent's closing tag lines
        up with its opening tag.
        """
        i = "\n" + level * "\t"
        if len(elem):
            if not elem.text or not elem.text.strip():
                elem.text = i + "\t"
            if not elem.tail or not elem.tail.strip():
                elem.tail = i
            for child in elem:
                self.xml_indent(child, level + 1)
            # The recursion gave the last child the deeper indent as its tail;
            # that tail precedes this element's closing tag, so dedent it.
            last = elem[-1]
            if not last.tail or not last.tail.strip():
                last.tail = i
        elif level and (not elem.tail or not elem.tail.strip()):
            elem.tail = i

    def generate(self, data):
        """
        Build one VOC ``<annotation>`` element per entry of ``data``.

        :param data: intermediate annotation dict (see the class docstring).
        :return: ``(True, {name: Element})`` where ``name`` is the key up to its
                 first ".", or ``(False, message)`` if an entry has no objects
                 or any error occurs.
        """
        try:
            xml_list = {}
            for key, element in data.items():
                xml_annotation = Element("annotation")

                xml_size = Et.SubElement(xml_annotation, "size")
                for dim in ("width", "height", "depth"):
                    # str(): size values may be ints (e.g. an int depth), and
                    # ElementTree cannot serialize non-string text.
                    Et.SubElement(xml_size, dim).text = str(element["size"][dim])

                Et.SubElement(xml_annotation, "segmented").text = "0"

                num_obj = int(element["objects"]["num_obj"])
                if num_obj < 1:
                    return False, "number of Object less than 1"

                for i in range(num_obj):
                    src = element["objects"][str(i)]
                    xml_object = Et.SubElement(xml_annotation, "object")
                    Et.SubElement(xml_object, "name").text = src["name"]
                    Et.SubElement(xml_object, "pose").text = "Unspecified"
                    Et.SubElement(xml_object, "truncated").text = "0"
                    Et.SubElement(xml_object, "difficult").text = "0"
                    xml_bndbox = Et.SubElement(xml_object, "bndbox")
                    for coord in self._BNDBOX_KEYS:
                        Et.SubElement(xml_bndbox, coord).text = str(src["bndbox"][coord])

                self.xml_indent(xml_annotation)
                xml_list[key.split(".")[0]] = xml_annotation
            return True, xml_list
        except Exception as e:
            return False, self._error_message(e)

    @staticmethod
    def save(xml_list, path):
        """
        Write each element of ``xml_list`` to ``<path>/<name>.xml``.

        :return: ``(True, None)`` on success, ``(False, message)`` on error.
        """
        try:
            out_dir = os.path.abspath(path)
            for name, xml in xml_list.items():
                ElementTree(xml).write(os.path.join(out_dir, name + ".xml"))
            return True, None
        except Exception as e:
            return False, VOC._error_message(e)

    @staticmethod
    def parse(path, img_path):
        """
        Read every ``.xml`` file in directory ``path`` into the intermediate dict.

        Annotations without any ``<object>`` are reported and skipped. Each
        entry's ``img_path`` is ``<img_path>/<stem>.jpg``; a missing image is
        reported, but the entry is still kept.

        :return: ``(True, {stem: {"img_path", "size", "objects"}})`` or
                 ``(False, message)`` on the first error.
        """
        try:
            dir_path, _, filenames = next(os.walk(os.path.abspath(path)))
            data = {}
            for filename in filenames:
                stem, ext = os.path.splitext(filename)
                if ext.lower() != ".xml":
                    continue  # ignore non-annotation files instead of failing on them
                xml_file = os.path.join(dir_path, filename)
                # Parsing by path lets ElementTree open *and close* the file.
                root = Et.parse(xml_file).getroot()

                xml_size = root.find("size")
                size = {dim: xml_size.find(dim).text for dim in ("width", "height", "depth")}

                objects = root.findall("object")
                if not objects:
                    print('xml has no obj: {}'.format(xml_file))
                    continue

                obj = {"num_obj": len(objects)}
                for obj_index, _object in enumerate(objects):
                    xml_bndbox = _object.find("bndbox")
                    obj[str(obj_index)] = {
                        "name": _object.find("name").text,
                        "bndbox": {coord: float(xml_bndbox.find(coord).text)
                                   for coord in VOC._BNDBOX_KEYS},
                    }

                jpg_path = os.path.join(img_path, stem + ".jpg")
                if not os.path.exists(jpg_path):
                    print('img not exists : {}'.format(jpg_path))
                data[stem] = {"img_path": jpg_path, "size": size, "objects": obj}
            return True, data
        except Exception as e:
            return False, VOC._error_message(e)
class YOLO:
    """
    Handler class for the YOLO (darknet) text label format.

    Each label file holds one object per line, ``class_id cx cy w h``. The box
    centre and size are divided by the image width/height (see
    ``coordinateCvt2YOLO``), and class ids index into ``cls_list``.
    """

    # When an object's name is not in ``cls_list``, the first pair whose
    # substring occurs in the name relabels the object as the paired class.
    NAME_ALIASES = (("limit", "limit"), ("van", "car"), ("rider", "person"))
    # Unknown names containing one of these substrings are dropped without a
    # warning ("trafiic light" presumably covers a misspelling in source labels).
    IGNORED_NAMES = ("traffic light", "trafiic light")

    def __init__(self, cls_list_path):
        """
        :param cls_list_path: text file with one class name per line; the
            0-based line number is the class id.
        """
        with open(cls_list_path, 'r') as file:
            self.cls_list = file.read().splitlines()

    @staticmethod
    def _error_message(e):
        """Format ``e`` as "ERROR : <e>, moreInfo : <type>\\t<file>\\t<line>"."""
        exc_tb = e.__traceback__
        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
        return "ERROR : {}, moreInfo : {}\t{}\t{}".format(e, type(e), fname, exc_tb.tb_lineno)

    def _class_name(self, name_id):
        """
        Map a class-id token from a label file to its name in ``cls_list``.

        Tokens that are not an in-range integer id are returned unchanged.
        """
        try:
            idx = int(name_id)
        except ValueError:
            return name_id
        return self.cls_list[idx] if 0 <= idx < len(self.cls_list) else name_id

    def coordinateCvt2YOLO(self, size, box):
        """
        Convert a pixel box to normalised YOLO coordinates.

        :param size: ``(image_width, image_height)`` in pixels.
        :param box: ``(xmin, xmax, ymin, ymax)`` in pixels -- note the x, x, y, y order.
        :return: ``(cx, cy, w, h)`` divided by the image size, rounded to 10 places.
        """
        dw = 1. / size[0]
        dh = 1. / size[1]
        x = (box[0] + box[1]) / 2.0  # box centre x: (xmin + xmax) / 2
        y = (box[2] + box[3]) / 2.0  # box centre y: (ymin + ymax) / 2
        w = box[1] - box[0]          # xmax - xmin
        h = box[3] - box[2]          # ymax - ymin
        x = x * dw
        w = w * dw
        y = y * dh
        h = h * dh
        return (round(x, 10), round(y, 10), round(w, 10), round(h, 10))

    def parse(self, label_path, img_path, img_type=".png"):
        """
        Read every ``.txt`` label file in ``label_path`` into the intermediate dict.

        The image ``<img_path>/<stem><img_type>`` is loaded only to learn its
        width and height. Class ids are translated to names via ``cls_list``.

        :return: ``(True, {stem: {"img_path", "size", "objects"}})`` or
                 ``(False, message)`` on the first error (including a missing image).
        """
        try:
            dir_path, _, filenames = next(os.walk(os.path.abspath(label_path)))
            data = {}
            for filename in filenames:
                stem, ext = os.path.splitext(filename)
                if ext.lower() != ".txt":
                    continue
                image_file = os.path.join(img_path, stem + img_type)
                # This module imports cv2 (not PIL). imread() returns None
                # instead of raising, so turn that into an explicit error.
                img = cv2.imread(image_file)
                if img is None:
                    raise FileNotFoundError("cannot read image: {}".format(image_file))
                img_height, img_width = img.shape[:2]
                size = {"width": str(img_width), "height": str(img_height), "depth": "3"}

                obj = {}
                obj_cnt = 0
                with open(os.path.join(dir_path, filename), "r") as txt:
                    for line in txt:
                        elements = line.split()
                        if not elements:
                            continue  # blank or trailing line
                        # Undo the normalisation: cx * 2W == xmin + xmax, w * W == box width.
                        xmin_plus_xmax = float(elements[1]) * (2.0 * img_width)
                        ymin_plus_ymax = float(elements[2]) * (2.0 * img_height)
                        w = float(elements[3]) * img_width
                        h = float(elements[4]) * img_height
                        xmin = (xmin_plus_xmax - w) / 2
                        ymin = (ymin_plus_ymax - h) / 2
                        obj[str(obj_cnt)] = {
                            "name": self._class_name(elements[0]),
                            "bndbox": {"xmin": xmin, "ymin": ymin,
                                       "xmax": xmin + w, "ymax": ymin + h},
                        }
                        obj_cnt += 1
                obj["num_obj"] = obj_cnt
                data[stem] = {"img_path": image_file, "size": size, "objects": obj}
            return True, data
        except Exception as e:
            return False, self._error_message(e)

    @staticmethod
    def _preview_box(img_path, box):
        """Debug aid: draw ``box`` (xmin, ymin, xmax, ymax) on the image and show it for 2 s."""
        if not img_path:
            return
        img = cv2.imread(img_path)
        if img is None:
            return
        xmin, ymin, xmax, ymax = (int(v) for v in box)
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (255, 0, 255), 3)
        cv2.imshow('test', img)
        cv2.waitKey(2000)

    def _resolve_name(self, key, entry, name, box, known):
        """
        Decide which class name an object is written under.

        :param known: mapping of class name -> id built from ``cls_list``.
        :return: a name present in ``known``, or None to drop the object.
        """
        if name in known:
            return name
        for needle, target in self.NAME_ALIASES:
            if needle in name:
                if target in known:
                    return target
                # The alias target itself is not a known class: skip this
                # object rather than aborting the whole conversion.
                print(target, 'not in cls list!------------')
                return None
        if 'wan' in name:
            print(key, '-----------------------------------------------------')
            self._preview_box(entry.get('img_path'), box)
            return None
        if not any(ignored in name for ignored in self.IGNORED_NAMES):
            print(name, 'not in cls list!------------')
        return None

    def generate(self, data):
        """
        Render YOLO label text for every image in ``data``.

        Objects whose names cannot be resolved to a class are skipped. ``data``
        itself is not modified.

        :return: ``(True, {key: "cls cx cy w h\\n" * n})`` or ``(False, message)``.
        """
        try:
            # First index of each class name, matching list.index() semantics.
            known = {}
            for cls_id, cls_name in enumerate(self.cls_list):
                known.setdefault(cls_name, cls_id)

            result = {}
            for key, entry in data.items():
                img_size = (int(entry["size"]["width"]), int(entry["size"]["height"]))
                objects = entry["objects"]
                lines = []
                for idx in range(int(objects["num_obj"])):
                    obj = objects[str(idx)]
                    bndbox = obj["bndbox"]
                    xmin, ymin = bndbox["xmin"], bndbox["ymin"]
                    xmax, ymax = bndbox["xmax"], bndbox["ymax"]
                    bb = self.coordinateCvt2YOLO(
                        img_size, (float(xmin), float(xmax), float(ymin), float(ymax)))
                    name = self._resolve_name(key, entry, obj["name"],
                                              (xmin, ymin, xmax, ymax), known)
                    if name is None:
                        continue
                    lines.append("{} {}\n".format(known[name], " ".join(str(v) for v in bb)))
                result[key] = "".join(lines)
            return True, result
        except Exception as e:
            return False, self._error_message(e)

    def save(self, data, save_path, img_path, img_type, manipast_path):
        """
        Write ``<save_path>/<key>.txt`` for each entry and list its image in a manifest.

        The manifest is ``<manipast_path>/manifast.txt`` (file name kept as-is),
        holding one absolute ``<img_path>/<key><img_type>`` path per line.

        :return: ``(True, None)`` on success, ``(False, message)`` on error.
        """
        try:
            manifest_file = os.path.abspath(os.path.join(manipast_path, "manifast.txt"))
            with open(manifest_file, "w") as manipast_file:
                for key in data:
                    image_file = os.path.abspath(os.path.join(img_path, key + img_type))
                    manipast_file.write(image_file + "\n")
                    label_file = os.path.abspath(os.path.join(save_path, key + ".txt"))
                    with open(label_file, "w") as output_txt_file:
                        output_txt_file.write(data[key])
            return True, None
        except Exception as e:
            return False, self._error_message(e)
# 查看标注文件在图片上的表现 — visualize how the annotation files look when drawn on their images.
def _yolo_to_pixel_box(img_w, img_h, cx, cy, bw, bh):
    """Convert a normalised YOLO box (centre x/y, width, height) to integer pixel corners (xmin, ymin, xmax, ymax)."""
    return (int(img_w * (cx - 0.5 * bw)),
            int(img_h * (cy - 0.5 * bh)),
            int(img_w * (cx + 0.5 * bw)),
            int(img_h * (cy + 0.5 * bh)))


def view_yolo_txt(txt_dir, img_dir, class_names=None):
    """
    Interactively review YOLO labels by drawing them on their images.

    For every ``.txt`` file in ``txt_dir``, the same-named ``.jpg`` in
    ``img_dir`` is loaded, each box is drawn with its class label, and the
    image is shown at half size until a key is pressed. Label files with no
    boxes are skipped; a missing image is reported and skipped.

    :param txt_dir: directory of YOLO label files (``class cx cy w h`` per line).
    :param img_dir: directory holding the matching ``.jpg`` images.
    :param class_names: optional sequence/mapping from class id to label text.
        When omitted, a module-level ``class_idx_list`` is used if one is
        defined; if neither resolves an id, the numeric id is drawn.
    """
    if class_names is None:
        # Backward compatibility: the original read a global ``class_idx_list``.
        class_names = globals().get("class_idx_list")

    for txtf in os.listdir(txt_dir):
        stem, ext = os.path.splitext(txtf)
        if ext != '.txt':
            continue
        with open(os.path.join(txt_dir, txtf), 'r') as f:
            rows = [[float(v) for v in line.split()] for line in f if line.strip()]
        if not rows:
            continue  # nothing to draw

        img_file = os.path.join(img_dir, stem + '.jpg')
        img = cv2.imread(img_file)
        if img is None:
            print('img not exists : {}'.format(img_file))
            continue
        h, w = img.shape[:2]

        for row in rows:
            # Ids are parsed as floats; indexing a list needs an int.
            cls_id = int(row[0])
            xmin, ymin, xmax, ymax = _yolo_to_pixel_box(w, h, *row[1:5])
            try:
                label = str(class_names[cls_id])
            except (TypeError, IndexError, KeyError):
                label = str(cls_id)
            cv2.putText(img, label, (xmin, ymin - 2), cv2.FONT_ITALIC, 1, (255, 255, 0), 1)
            cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)

        img = cv2.resize(img, None, fx=0.5, fy=0.5)
        print(img_file)
        cv2.imshow('tt', img)
        cv2.waitKey(0)
# yolo_to_coco