目标检测 (object detection): split XML annotations into image patches

# -*- encoding: utf-8 -*-
# Author  : haitong
# Time    : 2023/6/8 9:20
# File    : split_to_patch.py
# Software: PyCharm

import os
import cv2
import numpy
import xml.etree.ElementTree as ET
from tqdm import *
from PIL import Image


def make_file(path):
    """Create the directory ``path`` (including missing parents) if needed.

    Despite its name this creates a *directory*, not a file. It is safe to
    call repeatedly. ``exist_ok=True`` avoids the race between a separate
    existence check and the creation call.

    Args:
        path: directory path to create.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
        OSError: for other creation failures (e.g. permissions).
    """
    os.makedirs(path, exist_ok=True)


if __name__ == "__main__":
    dir_path = os.path.dirname(os.path.abspath(__file__))
    data_file_dir = os.path.join(dir_path, 'data', 'jpg')    # 原始图片路径
    xml_file_dir = os.path.join(dir_path, 'data', 'xml')  	 # xml路径
    out_file_dir = os.path.join(dir_path, 'data', 'patch')    # patch路径
    make_file(out_file_dir)

    all_number = 0
    for img_name in tqdm(sorted(os.listdir(data_file_dir))):
        if img_name.endswith('.jpg') or img_name.endswith('.png'):
            # print(img_name)
            img_path = os.path.join(data_file_dir, img_name)

            # img_bgr = cv2.imread(img_path)
            img_bgr = Image.open(img_path)
            img_bgr = cv2.cvtColor(numpy.asarray(img_bgr), cv2.COLOR_RGB2BGR)

            xml_name = img_name[:-3] + "xml"
            xml_path = os.path.join(xml_file_dir, xml_name)
            # print(xml_path)
            xml_inf = open(xml_path, encoding='utf-8')
            tree = ET.parse(xml_inf)
            root = tree.getroot()
            index = 1

            for obj in root.iter('object'):
                
                bbox_label = obj.find('name').text
             
                if bbox_label == "jxc":
                    bbox_top_left_x = int(obj.find('bndbox').find('xmin').text)
                    bbox_top_left_y = int(obj.find('bndbox').find('ymin').text)
                    bbox_bottom_right_x = int(obj.find('bndbox').find('xmax').text)
                    bbox_bottom_right_y = int(obj.find('bndbox').find('ymax').text)
                    img_ = img_bgr.copy()
                    cropped = img_[bbox_top_left_y:bbox_bottom_right_y, bbox_top_left_x:bbox_bottom_right_x, :]  # 裁剪坐标为[y0:y1, x0:x1]
                    if (bbox_bottom_right_y-bbox_top_left_y) * (bbox_bottom_right_x-bbox_top_left_x) <= 1600:
                        print("Error:", img_path, "\t", bbox_top_left_y, bbox_bottom_right_y, bbox_top_left_x, bbox_bottom_right_x)
                        continue
                    cv2.imwrite(os.path.join(out_file_dir, "{}+{}_{}_{}_{}_{}_{}.jpg".format(img_name[:-4], bbox_label, bbox_top_left_x, bbox_top_left_y, bbox_bottom_right_x, bbox_bottom_right_y, index)), cropped)

                    index += 1
                    all_number += 1

        # break

    print("all_number: ", all_number)

你可能感兴趣的 (you may also be interested in): 目标检测 (object detection), xml, python