7.1 yolov5优化模型时,自动标注xml数据

yolov5优化模型时,一般需要继续标注一些检测错误的图片,将其标为xml数据。以下是根据训练好的模型自动标注xml数据的python代码:

注意:代码中包含了本人的yolov5的测试过程,测试过程可以自己根据yolov5的测试文件自行修改,只是测试返回的类格式为:

[["water",[15,20,30,40]],["red",[12,13,14,15]]]

二维数组表示测试的类为water和red,其中后面的数字表示类的坐标:[top,left,bottom,right],表示上、左、下、右4个坐标。


import os
import cv2
from PIL import Image

from yolo import YOLO


#1.预测类,获得字符串
class Predict():

    def a(self, img_path,save_path,img_name):

        image = Image.open(img_path)

        r_image, pred = yolo.detect_image(image, pred_class, img_name)

        if not os.path.exists(dir_save_path):
            os.makedirs(dir_save_path)

        r_image.save(save_path, quality=95, subsampling=0)

        return pred


#2.写入xml文件
def img_xml(img_path,xml_path,img_name,pred):

    if len(pred) != 0:

        #1.读取图片(xml需要写入图片的长宽高)
        img = cv2.imread(img_path)

        #2.写入xml文件
        #(1)写入文件头部
        files_path=img_path.split("\\")[-2]
        print("..:",files_path)

        xml_file = open((xml_path + img_name + '.xml'), 'w')
        xml_file.write('\n')
        xml_file.write('	' +files_path+ '\n')
        xml_file.write('	' + img_name + '.jpg' + '\n')
        xml_file.write('	' + img_path +'\n')

        xml_file.write('	\n')
        xml_file.write('		Unknown\n')
        xml_file.write('	\n')

        #(2)写入图片的长宽高信息
        xml_file.write('	\n')
        xml_file.write('		'+str(img.shape[1])+'\n')
        xml_file.write('		' + str(img.shape[0]) + '\n')
        xml_file.write('		' + str(img.shape[2]) + '\n')
        xml_file.write('	\n')

        xml_file.write('	0\n')

        #3.写入字符串信息:[["water",[15,20,30,40]],["red",[12,13,14,15]]]
        #if len(shuzu)!=0:
        for item in pred:
            xml_file.write('	\n')
            xml_file.write('		' + str(item[0]) + '\n')
            xml_file.write('		Unspecified\n')
            xml_file.write('		0\n')
            xml_file.write('		0\n')
            xml_file.write('		\n')

            #写入字符串信息
            #[top, left, bottom, right]
            xml_file.write('			' + str(item[1][1]) + '\n')
            xml_file.write('			' + str(item[1][0]) + '\n')
            xml_file.write('			' + str(item[1][3]) + '\n')
            xml_file.write('			' + str(item[1][2]) + '\n')

            xml_file.write('		\n')
            xml_file.write('	\n')

        xml_file.write('\n')





if __name__ == "__main__":
    yolo = YOLO()
    ss = Predict()

    #需要修改以下4个量,并且要去VOCdevkit/VOC2007/文件夹下替换训练好的模型best_epoch_weights.pth和voc_classes.txt

    pred_class = ["car", "moto", "persons"]  # 填入需要检测的类名
    file_path = r"D:\AI\4.yolov5-pytorch-main_xml_write\save\image"  # 填入测试的图片路径
    dir_save_path = r"D:\AI\4.yolov5-pytorch-main_xml_write\save\image_save"# 填入保存的图片路径
    xml_path="save\\xml_save\\"# 填入保存的xml文件的路径

    ls=os.listdir(file_path)
    for item in ls:
        img_name=item
        xml_name=img_name.split(".")[0]+".xml"
        img_names=img_name.split(".")[0]

        img_path=os.path.join(file_path,img_name)
        save_path=os.path.join(dir_save_path,img_name)
        #xml_path=os.path.join(xml_path,xml_name)

        pred=ss.a(img_path,save_path,img_name)

        img_xml(img_path, xml_path, img_names, pred)

你可能感兴趣的:(7.数据处理,YOLO,xml,深度学习)