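一个标注文件的格式如下,和 VOC 的标注格式很像:
<annotation>
  <filename>0a8f1803ae7a0d65c5dd5561167e6a30</filename>
  <imagesize>
    <nrows>1920</nrows>
    <ncols>1080</ncols>
  </imagesize>
</annotation>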
"""
author:guopei
date:2020.02.23
"""
import os
from PIL import Image,ImageDraw,ImageFont
from tqdm import tqdm
import xml.dom.minidom
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from models.experimental import attempt_load
from utils.general import letterbox, non_max_suppression, scale_coords
class Yolov5Detect(object):
    """Inference wrapper around a YOLOv5 checkpoint.

    The model is loaded once, converted to FP16 and executed on a CUDA
    device. The pipeline is ``pre_process`` -> ``predict`` ->
    ``post_process``; callers normally only need ``post_process``.
    """

    def __init__(self, weights='./weights/yolov5x.pt', device=0, img_size=800, conf=0.65, iou=0.5):
        """Load the model and run a single warm-up forward pass.

        Args:
            weights: Checkpoint path handed to ``attempt_load``.
            device: CUDA device index; it is formatted into ``"cuda:<device>"``.
            img_size: Target size passed to ``letterbox``; should preferably
                be a multiple of 32.
            conf: Confidence threshold passed to ``non_max_suppression``.
            iou: IoU threshold passed to ``non_max_suppression``.
        """
        self.device = "cuda:%s" % device
        self.imgsz = img_size  # should preferably be a multiple of 32
        self.conf = conf
        self.iou = iou
        with torch.no_grad():
            self.model = attempt_load(weights, map_location=self.device)  # load FP32 model
            self.model.half()  # to FP16
            temp_img = torch.zeros((1, 3, self.imgsz, self.imgsz), device=self.device)  # init img
            _ = self.model(temp_img.half())  # run once

    def pre_process(self, img_path):
        """Read an image from disk and convert it into the network input layout.

        Args:
            img_path: Path of the image file.

        Returns:
            ``(img, img0)``: ``img`` is the letterboxed image as a contiguous
            channel-first RGB array, ``img0`` is the untouched array returned
            by ``cv2.imread``.

        Raises:
            FileNotFoundError: If ``cv2.imread`` cannot load ``img_path``.
        """
        img0 = cv2.imread(img_path)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if img0 is None:
            raise FileNotFoundError("Image Not Found " + img_path)
        img = letterbox(img0, new_shape=self.imgsz)[0]
        # Convert: BGR to RGB, then HWC to CHW
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img)
        return img, img0

    def predict(self, img_path):
        """Run the network on one image.

        Args:
            img_path: Path of the image file.

        Returns:
            ``(pred, img, img0)``: the first element of the model output, the
            FP16 input tensor (with a leading batch dimension) and the
            original image array.
        """
        img, img0 = self.pre_process(img_path)
        img = torch.from_numpy(img).to(self.device)
        img = img.half()  # uint8 to fp16
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)
        # Inference only: do not let autograd track the forward pass.
        with torch.no_grad():
            pred = self.model(img, augment=False)[0]
        return pred, img, img0

    def post_process(self, img_path):
        """Detect objects in one image and return them as plain Python lists.

        Args:
            img_path: Path of the image file.

        Returns:
            ``(pred, img0)``. ``pred`` is ``None`` when NMS yields ``None``
            for the image; otherwise it is a list with one row per detection
            (empty when nothing was detected). Callers in this file unpack a
            row as ``x1, y1, x2, y2, conf, label``. ``img0`` is the original
            image array.
        """
        pred, img, img0 = self.predict(img_path)
        # Apply NMS; only one image is processed, so take the first result.
        pred = non_max_suppression(pred, self.conf, self.iou, classes=None, agnostic=False)[0]
        if pred is None:
            return None, img0
        if len(pred):
            # Presumably maps the boxes from the network input size back to
            # the original image size -- see scale_coords.
            pred[:, :4] = scale_coords(img.shape[2:], pred[:, :4], img0.shape).round()
        return pred.cpu().detach().numpy().tolist(), img0  # from tensor to list
def get_image_list(image_dir, suffix=('jpg', 'jpeg', 'JPG', 'JPEG', 'png')):
    """Recursively collect the paths of image files below ``image_dir``.

    Args:
        image_dir: Directory to walk.
        suffix: Accepted file extensions without the leading dot. Matching is
            case-sensitive. A single string is treated as one extension.

    Returns:
        A list of file paths (in ``os.walk`` order). The list is empty when
        ``image_dir`` does not exist, in which case a message is printed.
    """
    if not os.path.exists(image_dir):
        print("PATH:%s not exists" % image_dir)
        return []

    # A bare string would otherwise be matched by substring.
    suffixes = (suffix,) if isinstance(suffix, str) else tuple(suffix)

    imglist = []
    for root, _sdirs, files in os.walk(image_dir):
        for filename in files:
            # splitext (unlike split('.')[-1]) yields no extension for a file
            # that is merely *named* "jpg", so such files are not picked up.
            extension = os.path.splitext(filename)[1][1:]
            if extension in suffixes:
                imglist.append(os.path.join(root, filename))
    return imglist
def _append_element(dom, parent, tag, text=None):
    """Create ``<tag>`` under ``parent`` and return it.

    When ``text`` is not ``None`` a text node is attached, even for the empty
    string, so that empty fields serialize as ``<tag></tag>``.
    """
    node = dom.createElement(tag)
    if text is not None:
        node.appendChild(dom.createTextNode(text))
    parent.appendChild(node)
    return node


def CreatXml(imgPath, results, xmlPath):
    """Write the detections of one image to an XML annotation file.

    Args:
        imgPath: Path of the image; it is read to obtain the image size and
            its base name (without extension) becomes ``<filename>``.
        results: Sequence of dicts with a ``'class'`` string and a ``'bbox'``
            sequence ``[x1, y1, x2, y2]``; one ``<object>`` is written per item.
        xmlPath: Destination file, written as UTF-8.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot load ``imgPath``.
    """
    img = cv2.imread(imgPath)
    if img is None:
        raise FileNotFoundError("Image Not Found " + imgPath)
    img_height, img_width = img.shape[:2]
    imgName = os.path.basename(imgPath)

    dom = xml.dom.minidom.getDOMImplementation().createDocument(None, 'annotation', None)
    root = dom.documentElement

    _append_element(dom, root, 'filename', os.path.splitext(imgName)[0])
    _append_element(dom, root, 'folder', '')

    source = _append_element(dom, root, 'source')
    _append_element(dom, source, 'sourceImage', '')
    _append_element(dom, source, 'sourceAnnotation', 'Datumaro')

    image_size = _append_element(dom, root, 'imagesize')
    # NOTE(review): 'nrows' receives the image width and 'ncols' the height.
    # This mapping is kept exactly as the original code had it; presumably it
    # is what the consuming annotation tool expects -- confirm before changing.
    _append_element(dom, image_size, 'nrows', str(int(img_width)))
    _append_element(dom, image_size, 'ncols', str(int(img_height)))

    for index, result in enumerate(results):
        img_object = _append_element(dom, root, 'object')
        _append_element(dom, img_object, 'name', result['class'])
        _append_element(dom, img_object, 'deleted', '0')
        _append_element(dom, img_object, 'verified', '0')
        _append_element(dom, img_object, 'occluded', 'no')
        _append_element(dom, img_object, 'date', '')
        _append_element(dom, img_object, 'id', str(index))

        parts = _append_element(dom, img_object, 'parts')
        _append_element(dom, parts, 'hasparts', '')
        _append_element(dom, parts, 'ispartof', '')

        _append_element(dom, img_object, 'type', 'bounding_box')

        # The box is stored as a two-point polygon: (x1, y1) then (x2, y2).
        polygon = _append_element(dom, img_object, 'polygon')
        x1, y1, x2, y2 = result['bbox'][:4]
        for x, y in ((x1, y1), (x2, y2)):
            point = _append_element(dom, polygon, 'pt')
            _append_element(dom, point, 'x', str(int(x)))
            _append_element(dom, point, 'y', str(int(y)))
        _append_element(dom, polygon, 'user_name', '')

        _append_element(dom, img_object, 'attributes', '')

    # Class names may be non-ASCII, so fix the encoding explicitly and declare
    # it in the XML header instead of relying on the platform default.
    with open(xmlPath, 'w', encoding='utf-8') as f:
        dom.writexml(f, addindent=' ', newl='\n', encoding='utf-8')
if __name__ == '__main__':
    # Auto-label every image under "imgs": detect persons with YOLOv5 and
    # write one XML annotation per image into "xmls".
    detector = Yolov5Detect()
    img_list = get_image_list("imgs")
    os.makedirs("xmls", exist_ok=True)  # the output directory may not exist yet

    for img_path in tqdm(img_list):
        pred, _img0 = detector.post_process(img_path)
        if pred is None:
            continue

        # Keep only persons; the person label is 0.
        persons = [det for det in pred if det[-1] == 0.0]
        # A list comprehension is never None, so test for emptiness to skip
        # images without any person.
        if not persons:
            continue

        objects = [
            {'class': "行人框(属性)", 'bbox': [int(x1), int(y1), int(x2), int(y2)]}
            for x1, y1, x2, y2, _conf, _label in persons
        ]

        # Swap whatever extension the image has for ".xml"; replacing only
        # ".jpg" would leave .png/.jpeg/.JPG names unchanged.
        xml_name = os.path.splitext(os.path.basename(img_path))[0] + ".xml"
        CreatXml(img_path, objects, os.path.join("xmls", xml_name))
注:代码比较清晰,可以直接拿去使用。