目录
VOC2012数据集
图片类别
xml文件的读取
code
链接:https://pan.baidu.com/s/1uV5j6BEkwd8yKLUhaUPzPQ?pwd=aaaa
提取码:aaaa
数据集目录:(共包含10张图片)
其中Annotations为10张图片的label(xml文件), ImageSets-main中的txt文档为10张图片的名字,JPEGImages为10张图片。
共21个类别(含背景),类别名见CLASSES_NAME。通过zip函数将类别名与序号0-20一一配对,再用dict转化为"类别名: 序号"形式的字典name2id。
"""类别字典的创建 class_name:序号 """
# Pascal VOC class names. A label's position in this tuple is its class index.
# NOTE: the background entry keeps its trailing space exactly as written.
CLASSES_NAME = (
    "__background__ ",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
)
# Map every class name to its index in CLASSES_NAME.
name2id = {label: index for index, label in enumerate(CLASSES_NAME)}
xml文件按树形结构组织,需要从根节点开始逐层向下访问。ET是xml.etree.ElementTree模块的别名:ET.parse解析label(xml文件)得到元素树,再通过.getroot()获取根节点anno。每个节点包含以下属性:
tag:标签,用于标识该元素表示哪种数据
attrib:属性,用字典形式保存
text:文本字符串,通过 .find(节点).text查看节点的内容
import xml.etree.ElementTree as ET
import os
import numpy as np
"""类别字典的创建 class_name:序号 """
# Pascal VOC class names. A label's position in this tuple is its class index.
# NOTE: the background entry keeps its trailing space exactly as written.
CLASSES_NAME = (
    "__background__ ",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
)
# Map every class name to its index in CLASSES_NAME.
name2id = {label: index for index, label in enumerate(CLASSES_NAME)}
def get_xml_label(label_path):
    """Read bounding boxes and class indices from one VOC-style XML label.

    Args:
        label_path: Path (or file object) of the XML annotation, passed
            straight to ``ET.parse``.

    Returns:
        A ``(boxes, classes)`` pair. ``boxes`` is a ``float32`` array of
        shape ``(K, 4)`` holding ``(xmin, ymin, xmax, ymax)`` per kept
        object, each coordinate shifted by -1. ``classes`` is a list of the
        ``K`` matching indices looked up in ``name2id``. When no object is
        kept, ``boxes`` has shape ``(0, 4)`` and ``classes`` is empty.

    Raises:
        KeyError: If an object's name is not a key of ``name2id``.
        AttributeError: If an object lacks ``bndbox``, a coordinate tag, or
            ``name``.
    """
    # Pixel positions in the file start at 1; shift them to start at 0.
    TO_REMOVE = 1
    coord_tags = ("xmin", "ymin", "xmax", "ymax")

    anno = ET.parse(label_path).getroot()  # root node of the annotation tree
    boxes = []
    classes = []
    for obj in anno.iter("object"):  # every <object> element under the root
        # Skip objects flagged as difficult. A missing or empty <difficult>
        # tag is treated as "not difficult" instead of raising.
        if int(obj.findtext("difficult") or 0) == 1:
            continue

        # Bounding-box corners, converted to float and shifted by TO_REMOVE.
        _box = obj.find("bndbox")
        boxes.append(
            tuple(float(_box.find(tag).text) - TO_REMOVE for tag in coord_tags)
        )

        # Class name normalised to lower case with surrounding whitespace
        # (including newlines) removed, then mapped to its index.
        name = obj.find("name").text.lower().strip()
        classes.append(name2id[name])

    # reshape keeps the (K, 4) layout even when no box was kept (K == 0).
    boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
    return boxes, classes
# Path template for the annotation files; "%s" is replaced by the image id.
label_path = os.path.join(r'D:\VOC2012\Annotations', '%s.xml')
# Load the label of one sample image and show the parsed result.
image_id = '2008_000007'
boxes, classes = get_xml_label(label_path % image_id)
print(boxes)
print(classes)