PASCAL VOC是一个国际的计算机视觉挑战赛,数据集包含了20个分类的3万多张图片。挑战赛及其数据集基础上涌现不少知名的目标检测模型如R-CNN,YOLO,SSD等。可以通过下载和读取的方法载入PASCAL VOC数据集。
1 数据集下载
PASCAL VOC数据集可以从官方网站下载
http://host.robots.ox.ac.uk/pascal/VOC/
常用的是2007版和2012版,下载后是一个tar文件包
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
从外网下载速度较慢,如果用迅雷等下载,会自动从国内镜像下载,可以取得较快的速度
下载之后解压,出现VOCdevkit目录,里面包含VOC2012目录,下面包含Annotations,JPEGImages等目录
VOCdevkit #根目录
VOC2012 #2012版
Annotations #标注
ImageSets #分类信息
JPEGImages #jpeg图像目录
SegmentationClass #类别分割信息
SegmentationObject #物体分割信息
2 数据集读取
首先载入需要的模块
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import xml.dom.minidom as minidom
xml模块用来读取Annotations目录中的xml文件,VOC数据集的标注用xml文件来表示。
定义数据目录变量如下
data_path = '../data/PASCAL_VOC/VOCdevkit/VOC2012'
anno_path = f'{data_path}/Annotations'
image_path = f'{data_path}/JPEGImages'
然后从Annotations目录读取一个图片信息的列表,这里读取100张
# collect 100 samples from files in anno_path
sample_list = []
for xml in os.listdir(annotation_path):
file_id = os.path.splitext(xml)[0]
xml_file = f'{anno_path}/{file_id}.xml'
jpg_file = f'{image_path}/{file_id}.jpg'
sample = {'id':file_id, 'xml_file':xml_file, 'jpg_file':jpg_file}
sample_list.append(sample)
if len(sample_list) >= 100:
break
生成一个列表,每一个元素包含图片的id,对应的xml标注文件和图像文件(jpeg)。
然后定义一个从xml文件读取目标信息的函数。xml的文件格式大致如下
-<annotation>
<folder>VOC2012folder>
<filename>2007_000333.jpgfilename>
-<source>
<database>The VOC2007 Databasedatabase>
<annotation>PASCAL VOC2007annotation>
<image>flickrimage>
source>
-<size>
<width>500width>
<height>333height>
<depth>3depth>
size>
<segmented>1segmented>
-<object>
<name>trainname>
<pose>Unspecifiedpose>
<truncated>1truncated>
<difficult>0difficult>
-<bndbox>
<xmin>1xmin>
<ymin>39ymin>
<xmax>367xmax>
<ymax>270ymax>
bndbox>
object>
annotation>
其中包含
# get information of objects from xml file
def get_objects(xml_file):
dom = minidom.parse(xml_file)
anno = dom.documentElement
objects = anno.getElementsByTagName('object')
obj_names = []
for obj in objects:
obj_name = obj.getElementsByTagName('name')[0].childNodes[0].data
obj_names.append(obj_name)
return obj_names
即从根节点出发,寻找object子节点,遍历所有子节点,把对应的目标名称加入到object_names,最后返回这张图片所包含的所有目标的名称。
然后,显示数据集中的图片和对应的分类目标信息,限于屏幕大小,这里显示最前面的25张图片
# lines and columns of subplots
m = 5
n = 5
num = m*n
# size of figure
plt.figure(figsize=(14,13))
# plot first 25 pictures
for i in range(num):
plt.subplot(m,n,i+1)
objects = get_objects(sample_list[i]['xml_file'])
img = mpimg.imread(sample_list[i]['jpg_file'])
plt.imshow(img)
plt.xticks([])
plt.yticks([])
plt.xlabel(','.join(objects))
plt.show()