抽取VOC数据集中的部分类别数据

# -*- encoding: utf-8 -*-

import os
import xml.etree.ElementTree as ET
import shutil

# VOC2007 is the original dataset; VOC2007_8 is the output folder that
# receives the annotations and images of the selected classes.
ann_filepath = 'VOCdevkit/VOC2007/Annotations/'
img_filepath = 'VOCdevkit/VOC2007/JPEGImages/'
img_savepath = 'VOCdevkit/VOC2007_8/JPEGImages/'
ann_savepath = 'VOCdevkit/VOC2007_8/Annotations/'

# os.makedirs also creates the missing parent folder (VOCdevkit/VOC2007_8),
# whereas os.mkdir raises when the parent does not exist yet.
for _save_dir in (img_savepath, ann_savepath):
    if not os.path.isdir(_save_dir):
        os.makedirs(_save_dir)

# The 8 object classes to keep.
classes = ['aeroplane', 'bus', 'cat', 'dog', 'horse', 'motorbike', 'person', 'train']


def save_annotation(file):
    """Filter one annotation file down to the selected classes and save it.

    Parses ``file`` from ``ann_filepath``, removes every top-level
    ``<object>`` whose ``<name>`` is not listed in ``classes`` (objects
    without a ``<name>`` child are removed as well), and writes the result
    under the same file name into ``ann_savepath`` when at least one object
    remains.

    Args:
        file: File name of the annotation XML inside ``ann_filepath``.

    Returns:
        True if the filtered annotation was written, False if no object of
        a selected class was found (nothing is written in that case).
    """
    tree = ET.parse(os.path.join(ann_filepath, file))
    root = tree.getroot()

    kept_any = False
    # findall() returns a list, so removing children from root inside the
    # loop does not disturb the iteration.
    for obj in root.findall("object"):
        # findtext() yields None when <name> is missing instead of raising
        # AttributeError the way find("name").text would.
        if obj.findtext("name") in classes:
            kept_any = True
        else:
            root.remove(obj)

    if not kept_any:
        return False

    tree.write(os.path.join(ann_savepath, file))
    return True


def save_images(file):
    """Copy the JPEG that belongs to annotation ``file`` into ``img_savepath``.

    The image is looked up in ``img_filepath`` under the annotation's base
    name with a ``.jpg`` extension.

    Args:
        file: Annotation file name; only the part before its extension is
            used.

    Returns:
        Always True (a failed copy surfaces as an exception from
        ``shutil.copy``).
    """
    stem, _ext = os.path.splitext(file)
    source_jpg = "{}{}.jpg".format(img_filepath, stem)
    shutil.copy(source_jpg, img_savepath)
    return True


if __name__ == '__main__':
    # Walk every annotation file; the matching image is copied only when the
    # annotation still contains at least one object of a selected class.
    for f in os.listdir(ann_filepath):
        # Skip anything that is not an annotation XML (hidden files,
        # sub-folders, ...): the XML parser would raise on such entries and
        # abort the whole run.
        if not f.lower().endswith('.xml'):
            continue
        if save_annotation(f):
            save_images(f)

可以通过如下代码统计抽取后的标注文件中各类别的目标数量,以检验其中是否只包含自己选择的类别:

import os
import xml.dom.minidom

# Folder that holds the extracted annotation XML files.
xml_path = 'VOCdevkit/VOC2007_8/Annotations/'

# Maps each object class name to the number of objects counted for it.
gt_dict = dict()

# Names of all entries found in the annotation folder.
files = os.listdir(xml_path)

if __name__ == '__main__':

    # Count how many objects of each class appear across all annotations.
    for xml_name in files:
        document = xml.dom.minidom.parse(xml_path + xml_name)
        root = document.documentElement
        # The image file name is only printed for suspicious entries below,
        # but it is read up front for every annotation.
        filename = root.getElementsByTagName("filename")[0].childNodes[0].data
        for obj_node in root.getElementsByTagName("object"):
            # The first <name> inside the object is taken as its class name.
            objectname = obj_node.getElementsByTagName("name")[0].childNodes[0].data
            if objectname == '-':
                # Report the image whose annotation carries the '-' name.
                print(filename)
            gt_dict[objectname] = gt_dict.get(objectname, 0) + 1

    # (class name, count) pairs, most frequent class first.
    dic = list(gt_dict.items())
    dic.sort(key=lambda pair: pair[1], reverse=True)
    print(dic)
    print(len(dic))

打印输出结果:

[('person', 5227), ('dog', 530), ('horse', 395), ('cat', 370), ('motorbike', 369), ('aeroplane', 311), ('train', 302), ('bus', 254)]
8

参考:https://www.cnblogs.com/dan-baishucaizi/p/11911810.html 

你可能感兴趣的:(Python代码,深度学习,目标检测,深度学习,python,目标检测)