yolov5优化模型时,一般需要继续标注一些检测错误的图片,将其标为xml数据。以下是根据训练好的模型自动标注xml数据的python代码:
注意:代码中包含了本人的yolov5的测试过程,测试过程可以自己根据yolov5的测试文件自行修改,只是测试返回的类格式为:
[["water",[15,20,30,40]],["red",[12,13,14,15]]]
二维数组表示测试的类为water和red,其中后面的数字表示类的坐标:[top,left,bottom,right],表示上、左、下、右4个坐标。
import os
import cv2
from PIL import Image
from yolo import YOLO
#1.预测类,获得字符串
class Predict():
def a(self, img_path,save_path,img_name):
image = Image.open(img_path)
r_image, pred = yolo.detect_image(image, pred_class, img_name)
if not os.path.exists(dir_save_path):
os.makedirs(dir_save_path)
r_image.save(save_path, quality=95, subsampling=0)
return pred
#2.写入xml文件
def img_xml(img_path,xml_path,img_name,pred):
if len(pred) != 0:
#1.读取图片(xml需要写入图片的长宽高)
img = cv2.imread(img_path)
#2.写入xml文件
#(1)写入文件头部
files_path=img_path.split("\\")[-2]
print("..:",files_path)
xml_file = open((xml_path + img_name + '.xml'), 'w')
xml_file.write('\n')
xml_file.write(' ' +files_path+ ' \n')
xml_file.write(' ' + img_name + '.jpg' + ' \n')
xml_file.write(' ' + img_path +' \n')
xml_file.write(' \n')
#(2)写入图片的长宽高信息
xml_file.write(' \n')
xml_file.write(' '+str(img.shape[1])+' \n')
xml_file.write(' ' + str(img.shape[0]) + ' \n')
xml_file.write(' ' + str(img.shape[2]) + ' \n')
xml_file.write(' \n')
xml_file.write(' 0 \n')
#3.写入字符串信息:[["water",[15,20,30,40]],["red",[12,13,14,15]]]
#if len(shuzu)!=0:
for item in pred:
xml_file.write(' \n')
xml_file.write(' \n')
if __name__ == "__main__":
yolo = YOLO()
ss = Predict()
#需要修改以下4个量,并且要去VOCdevkit/VOC2007/文件夹下替换训练好的模型best_epoch_weights.pth和voc_classes.txt
pred_class = ["car", "moto", "persons"] # 填入需要检测的类名
file_path = r"D:\AI\4.yolov5-pytorch-main_xml_write\save\image" # 填入测试的图片路径
dir_save_path = r"D:\AI\4.yolov5-pytorch-main_xml_write\save\image_save"# 填入保存的图片路径
xml_path="save\\xml_save\\"# 填入保存的xml文件的路径
ls=os.listdir(file_path)
for item in ls:
img_name=item
xml_name=img_name.split(".")[0]+".xml"
img_names=img_name.split(".")[0]
img_path=os.path.join(file_path,img_name)
save_path=os.path.join(dir_save_path,img_name)
#xml_path=os.path.join(xml_path,xml_name)
pred=ss.a(img_path,save_path,img_name)
img_xml(img_path, xml_path, img_names, pred)