# pascal_voc.py
# The pascal_voc class organizes the input image data: it stores information about the images, but not the images themselves. pascal_voc is a subclass of the imdb class.
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
# This class inherits from imdb and is responsible for the data-interaction part.
import os
from datasets.imdb import imdb
import datasets.ds_utils as ds_utils
import xml.etree.ElementTree as ET #ElementTree表示整个XML树,Element表示树上的单个节点
import numpy as np
import scipy.sparse
import scipy.io as sio
import utils.cython_bbox
import cPickle
import subprocess
import uuid
from voc_eval import voc_eval
from fast_rcnn.config import cfg
class pascal_voc(imdb):
    """Dataset wrapper (an ``imdb`` subclass) for data in the PASCAL VOC layout.

    The object keeps image identifiers, file paths and parsed annotations; it
    never holds the image pixels themselves.
    """

    def __init__(self, image_set, year, devkit_path=None):
        """Set up paths, class names and the image index for one image set.

        Args:
            image_set: name of the split; ``ImageSets/Main/<image_set>.txt``
                must exist below the data path.
            year: dataset year as a string, e.g. ``'2007'``.
            devkit_path: root directory of the devkit. When ``None`` the
                value of ``_get_default_path()`` is used.

        Raises:
            AssertionError: if the image-set file, the devkit directory or
                the data directory does not exist.
        """
        # The imdb name has the form 'voc_<year>_<image_set>'.
        imdb.__init__(self, 'voc_' + year + '_' + image_set)
        self._year = year
        self._image_set = image_set
        if devkit_path is None:
            devkit_path = self._get_default_path()
        self._devkit_path = devkit_path
        self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
        # Index 0 is the background; the remaining entries are the object
        # categories of this dataset.
        self._classes = ('__background__', 'ant', 'butterfly', 'cicadas',
                         'dragonfly', 'ladybug', 'mantis', 'honeybee', 'fly',
                         'grasshopper', 'cricket')
        # Class name -> integer label, e.g. {'__background__': 0, 'ant': 1}.
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        self._image_ext = '.jpg'
        # List of image identifiers (file names without extension).
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        self._roidb_handler = self.selective_search_roidb
        # Random suffix that keeps result files of separate runs apart.
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

        # PASCAL specific config options
        self.config = {'cleanup': True,
                       'use_salt': True,
                       'use_diff': False,
                       'matlab_eval': False,
                       'rpn_file': None,
                       'min_size': 2}

        assert os.path.exists(self._devkit_path), \
            'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
            'Path does not exist: {}'.format(self._data_path)

    def image_path_at(self, i):
        """Return the absolute path to image i in the image sequence."""
        return self.image_path_from_index(self._image_index[i])

    def image_path_from_index(self, index):
        """Construct an image path from the image's "index" identifier.

        Raises:
            AssertionError: if the resulting file does not exist.
        """
        image_path = os.path.join(self._data_path, 'JPEGImages',
                                  index + self._image_ext)
        assert os.path.exists(image_path), \
            'Path does not exist: {}'.format(image_path)
        return image_path

    def _load_image_set_index(self):
        """Load the indexes listed in this dataset's image set file.

        Returns:
            A list with one stripped, non-empty line per image identifier
            (identifiers carry no file extension).

        Raises:
            AssertionError: if the image-set file does not exist.
        """
        # Example path to image set file:
        # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
        image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                      self._image_set + '.txt')
        assert os.path.exists(image_set_file), \
            'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f:
            stripped = (line.strip() for line in f)
            # Skip blank lines (e.g. a trailing newline at the end of the
            # file); an empty identifier can never resolve to an image.
            image_index = [name for name in stripped if name]
        return image_index

    def _get_default_path(self):
        """Return the default path where PASCAL VOC is expected to be installed."""
        return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year)

    def gt_roidb(self):
        """Return the database of ground-truth regions of interest.

        This function loads/saves from/to a cache file to speed up future
        calls.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            # NOTE: unpickling executes code from the file; the cache
            # directory must only contain files written by this project.
            with open(cache_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            print('{} gt roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        # No cache yet: parse every annotation XML and store the result.
        gt_roidb = [self._load_pascal_annotation(index)
                    for index in self.image_index]
        with open(cache_file, 'wb') as fid:
            cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))
        return gt_roidb

    def selective_search_roidb(self):
        """Return the database of selective search regions of interest.

        Ground-truth ROIs are also included, except for test sets of years
        other than 2007.

        This function loads/saves from/to a cache file to speed up future
        calls.
        """
        cache_file = os.path.join(self.cache_path,
                                  self.name + '_selective_search_roidb.pkl')
        if os.path.exists(cache_file):
            # NOTE: only load cache files produced by this project.
            with open(cache_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            print('{} ss roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        if int(self._year) == 2007 or self._image_set != 'test':
            gt_roidb = self.gt_roidb()
            ss_roidb = self._load_selective_search_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, ss_roidb)
        else:
            roidb = self._load_selective_search_roidb(None)
        with open(cache_file, 'wb') as fid:
            cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print('wrote ss roidb to {}'.format(cache_file))
        return roidb

    def rpn_roidb(self):
        """Return the roidb built from the proposals in ``config['rpn_file']``.

        For year 2007, or for any image set other than 'test', the
        ground-truth roidb is loaded as well and merged with the proposals.
        """
        if int(self._year) == 2007 or self._image_set != 'test':
            gt_roidb = self.gt_roidb()
            rpn_roidb = self._load_rpn_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
        else:
            roidb = self._load_rpn_roidb(None)
        return roidb

    def _load_rpn_roidb(self, gt_roidb):
        """Load the pickled proposal boxes and turn them into a roidb.

        Args:
            gt_roidb: ground-truth roidb, or ``None``; forwarded unchanged to
                ``create_roidb_from_box_list``.

        Raises:
            AssertionError: if ``config['rpn_file']`` is unset or missing.
        """
        filename = self.config['rpn_file']
        print('loading {}'.format(filename))
        # Check for None first: os.path.exists(None) would raise TypeError
        # instead of the intended assertion message.
        assert filename is not None and os.path.exists(filename), \
            'rpn data not found at: {}'.format(filename)
        # NOTE: unpickling executes code from the file; only point
        # 'rpn_file' at proposal files generated by a trusted run.
        with open(filename, 'rb') as f:
            box_list = cPickle.load(f)
        return self.create_roidb_from_box_list(box_list, gt_roidb)

    def _load_selective_search_roidb(self, gt_roidb):
        """Load selective-search boxes from a ``.mat`` file into a roidb.

        The file is expected at
        ``<cfg.DATA_DIR>/selective_search_data/<imdb name>.mat`` and must
        contain a ``boxes`` entry with one box array per image.

        Raises:
            AssertionError: if the ``.mat`` file does not exist.
        """
        filename = os.path.abspath(os.path.join(cfg.DATA_DIR,
                                                'selective_search_data',
                                                self.name + '.mat'))
        assert os.path.exists(filename), \
            'Selective search data not found at: {}'.format(filename)
        raw_data = sio.loadmat(filename)['boxes'].ravel()

        box_list = []
        for raw_boxes in raw_data:
            # Swap the column pairs and shift by one -- presumably converts
            # MATLAB 1-based (y1, x1, y2, x2) rows to 0-based
            # (x1, y1, x2, y2); confirm against the producer of the file.
            boxes = raw_boxes[:, (1, 0, 3, 2)] - 1
            keep = ds_utils.unique_boxes(boxes)
            boxes = boxes[keep, :]
            keep = ds_utils.filter_small_boxes(boxes, self.config['min_size'])
            boxes = boxes[keep, :]
            box_list.append(boxes)

        return self.create_roidb_from_box_list(box_list, gt_roidb)

    @staticmethod
    def _is_difficult(obj):
        """Return True when an ``<object>`` element is flagged as difficult.

        A missing or empty ``<difficult>`` tag counts as "not difficult" so
        that annotation files without this tag still load.
        """
        node = obj.find('difficult')
        if node is None or node.text is None:
            return False
        return int(node.text) != 0

    def _load_pascal_annotation(self, index):
        """Load image and bounding boxes info from XML file in the PASCAL VOC
        format.

        Args:
            index: image identifier; ``Annotations/<index>.xml`` is parsed.

        Returns:
            A dict with the keys ``'boxes'`` (uint16 array, one
            ``[x1, y1, x2, y2]`` row per object), ``'gt_classes'`` (int32
            label per object), ``'gt_overlaps'`` (sparse CSR matrix with a
            1.0 in each object's class column), ``'flipped'`` (always
            ``False``) and ``'seg_areas'`` (float32 box area per object).
        """
        filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
        tree = ET.parse(filename)
        objs = tree.findall('object')
        if not self.config['use_diff']:
            # Exclude the samples labeled as difficult
            objs = [obj for obj in objs if not self._is_difficult(obj)]
        num_objs = len(objs)

        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        # "Seg" area for pascal is just the box area
        seg_areas = np.zeros((num_objs), dtype=np.float32)

        # Load object bounding boxes into a data frame.
        for ix, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            # Make pixel indexes 0-based. Clamp at zero: an annotation value
            # of 0 would otherwise become -1 and wrap around to 65535 when
            # stored in the unsigned 'boxes' array.
            x1, y1, x2, y2 = [
                max(float(bbox.find(tag).text) - 1, 0.0)
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
            # Class names are matched lower-cased and stripped of whitespace.
            cls = self._class_to_ind[obj.find('name').text.lower().strip()]
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls
            # A ground-truth box overlaps fully with its own class only.
            overlaps[ix, cls] = 1.0
            seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

        overlaps = scipy.sparse.csr_matrix(overlaps)

        return {'boxes': boxes,
                'gt_classes': gt_classes,
                'gt_overlaps': overlaps,
                'flipped': False,
                'seg_areas': seg_areas}

    def _get_comp_id(self):
        """Return the competition id, suffixed with the salt if enabled."""
        if self.config['use_salt']:
            return self._comp_id + '_' + self._salt
        return self._comp_id

    def _get_voc_results_file_template(self):
        """Return the results file path with a ``{:s}`` slot for the class.

        Example: <devkit>/results/VOC2007/Main/<comp_id>_det_test_ant.txt
        """
        filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'
        path = os.path.join(
            self._devkit_path,
            'results',
            'VOC' + self._year,
            'Main',
            filename)
        return path

    def _write_voc_results_file(self, all_boxes):
        """Write one detection results file per non-background class.

        Args:
            all_boxes: indexed as ``all_boxes[class_index][image_index]``;
                each entry is either empty or a 2-D array whose rows start
                with ``x1, y1, x2, y2`` and end with the score.
        """
        template = self._get_voc_results_file_template()
        # Create the results directory up front; open() would otherwise
        # fail on a devkit that has no 'results' tree yet.
        results_dir = os.path.dirname(template)
        if not os.path.isdir(results_dir):
            os.makedirs(results_dir)

        for cls_ind, cls in enumerate(self.classes):
            if cls == '__background__':
                continue
            print('Writing {} VOC results file'.format(cls))
            filename = template.format(cls)
            with open(filename, 'wt') as f:
                for im_ind, index in enumerate(self.image_index):
                    dets = all_boxes[cls_ind][im_ind]
                    # len() works for both an empty list and an array with
                    # zero rows; 'dets == []' is unreliable for arrays.
                    if len(dets) == 0:
                        continue
                    # the VOCdevkit expects 1-based indices
                    for k in xrange(dets.shape[0]):
                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                                format(index, dets[k, -1],
                                       dets[k, 0] + 1, dets[k, 1] + 1,
                                       dets[k, 2] + 1, dets[k, 3] + 1))

    def _do_python_eval(self, output_dir='output'):
        """Evaluate the written result files with ``voc_eval`` and print APs.

        For every non-background class the recall, precision and AP are
        pickled to ``<output_dir>/<class>_pr.pkl``.

        Args:
            output_dir: directory for the per-class pickle files; created
                if it does not exist.
        """
        annopath = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
            'Annotations',
            '{:s}.xml')
        imagesetfile = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
            'ImageSets',
            'Main',
            self._image_set + '.txt')
        cachedir = os.path.join(self._devkit_path, 'annotations_cache')
        aps = []
        # The PASCAL VOC metric changed in 2010
        use_07_metric = int(self._year) < 2010
        print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
        if not os.path.isdir(output_dir):
            # makedirs also creates missing parent directories.
            os.makedirs(output_dir)
        for cls in self._classes:
            if cls == '__background__':
                continue
            filename = self._get_voc_results_file_template().format(cls)
            rec, prec, ap = voc_eval(
                filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
                use_07_metric=use_07_metric)
            aps.append(ap)
            print('AP for {} = {:.4f}'.format(cls, ap))
            # Pickle data is binary; always write it in binary mode.
            with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
                cPickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
        print('Mean AP = {:.4f}'.format(np.mean(aps)))
        print('~~~~~~~~')
        print('Results:')
        for ap in aps:
            print('{:.3f}'.format(ap))
        print('{:.3f}'.format(np.mean(aps)))
        print('~~~~~~~~')
        print('')
        print('--------------------------------------------------------------')
        print('Results computed with the **unofficial** Python eval code.')
        print('Results should be very close to the official MATLAB eval code.')
        print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
        print('-- Thanks, The Management')
        print('--------------------------------------------------------------')

    def _do_matlab_eval(self, output_dir='output'):
        """Run the MATLAB ``voc_eval`` wrapper in a subprocess.

        The command is executed through the shell from the
        ``lib/datasets/VOCdevkit-matlab-wrapper`` directory below
        ``cfg.ROOT_DIR``, using the MATLAB executable named by
        ``cfg.MATLAB``.
        """
        print('-----------------------------------------------------')
        print('Computing results with the official MATLAB eval code.')
        print('-----------------------------------------------------')
        path = os.path.join(cfg.ROOT_DIR, 'lib', 'datasets',
                            'VOCdevkit-matlab-wrapper')
        # NOTE: the command line is assembled from configuration values and
        # run with shell=True; those values must come from a trusted source.
        cmd = 'cd {} && '.format(path)
        cmd += '{:s} -nodisplay -nodesktop '.format(cfg.MATLAB)
        cmd += '-r "dbstop if error; '
        cmd += 'voc_eval(\'{:s}\',\'{:s}\',\'{:s}\',\'{:s}\'); quit;"' \
               .format(self._devkit_path, self._get_comp_id(),
                       self._image_set, output_dir)
        print('Running:\n{}'.format(cmd))
        status = subprocess.call(cmd, shell=True)
        if status != 0:
            # Report the failure but keep going, as before.
            print('MATLAB evaluation exited with status {}'.format(status))

    def evaluate_detections(self, all_boxes, output_dir):
        """Write result files, evaluate them and optionally clean up.

        Args:
            all_boxes: detections, indexed ``[class_index][image_index]``.
            output_dir: directory that receives the evaluation output.
        """
        self._write_voc_results_file(all_boxes)
        self._do_python_eval(output_dir)
        if self.config['matlab_eval']:
            self._do_matlab_eval(output_dir)
        if self.config['cleanup']:
            # Remove the per-class result files written above.
            for cls in self._classes:
                if cls == '__background__':
                    continue
                filename = self._get_voc_results_file_template().format(cls)
                os.remove(filename)

    def competition_mode(self, on):
        """Switch competition mode on or off.

        When on, result files are named without the salt and are kept after
        evaluation; when off, the salt is used and the files are removed.
        """
        enabled = bool(on)
        self.config['use_salt'] = not enabled
        self.config['cleanup'] = not enabled
if __name__ == '__main__':
    # Manual smoke test: build the 2007 trainval dataset, load its roidb and
    # drop into an interactive IPython shell for inspection.
    from datasets.pascal_voc import pascal_voc

    d = pascal_voc('trainval', '2007')
    res = d.roidb
    # Imported only here so that IPython is required for this script path
    # alone; 'd' and 'res' stay available inside the embedded shell.
    from IPython import embed
    embed()