Pytorch-Faster-RCNN 中的 MAP 实现 (解析imdb.py 和 pascal_voc.py)

---恢复内容开始---

mAP(mean Average Precision)是衡量 object detection 算法的重要指标(criterion),然而一直没有仔细阅读相关代码,今天就好好看一下:

1. 测试(test)过程由 FRCN/tools/test_net.py 中调用的 test_net() 完成(from model.test import test_net)。

test_net() 定义在 FRCN/lib/model/test.py 中,它在第 193-194 行调用了 imdb.evaluate_detections:

print('Evaluating detections')
imdb.evaluate_detections(all_boxes, output_dir)

imdb 是在 FRCN/tools/test_net.py(第 84 行)中创建、再作为参数传给 test_net() 的:

imdb = get_imdb(args.imdb_name)

get_imdb 来自 from datasets.factory import get_imdb。为了了解 imdb 是如何定义的,我们来看 FRCN/lib/datasets/factory.py:

 1 """Factory method for easily getting imdbs by name."""
 2 from __future__ import absolute_import
 3 from __future__ import division
 4 from __future__ import print_function
 5 
 6 __sets = {}
 7 from datasets.pascal_voc import pascal_voc
 8 
 9 import numpy as np
10 
11 # Set up voc__ 
12 for year in ['2007', '2012']:
13   for split in ['train', 'val', 'trainval', 'test']:
14     name = 'voc_{}_{}'.format(year, split)
15     __sets[name] = (lambda split=split, year=year: pascal_voc(split, year))
16 
17 for year in ['2007', '2012']:
18   for split in ['train', 'val', 'trainval', 'test']:
19     name = 'voc_{}_{}_diff'.format(year, split)
20     __sets[name] = (lambda split=split, year=year: pascal_voc(split, year, use_diff=True))
21 
22 def get_imdb(name):
23   """Get an imdb (image database) by name."""
24   if name not in __sets:
25     raise KeyError('Unknown dataset: {}'.format(name))
26   return __sets[name]()
27 
28 def list_imdbs():
29   """List all registered imdbs."""
30   return list(__sets.keys())

coco 数据集的注册方式与 pascal_voc 相同(上面的代码中未列出)。可以看到,get_imdb(args.imdb_name) 返回的就是一个 pascal_voc(split, year) 对象(名字带 _diff 后缀时为 pascal_voc(split, year, use_diff=True))。

 

2. 来到pascal_voc.py :

  1 # --------------------------------------------------------
  2 # Fast R-CNN
  3 # Copyright (c) 2015 Microsoft
  4 # Licensed under The MIT License [see LICENSE for details]
  5 # Written by Ross Girshick and Xinlei Chen
  6 # --------------------------------------------------------
  7 from __future__ import absolute_import
  8 from __future__ import division
  9 from __future__ import print_function
 10 
 11 import os
 12 from datasets.imdb import imdb
 13 import datasets.ds_utils as ds_utils
 14 import xml.etree.ElementTree as ET
 15 import numpy as np
 16 import scipy.sparse
 17 import scipy.io as sio
 18 import pickle
 19 import subprocess
 20 import uuid
 21 from .voc_eval import voc_eval
 22 from model.config import cfg
 23 
 24 
 25 class pascal_voc(imdb):
 26   def __init__(self, image_set, year, use_diff=False):
 27     name = 'voc_' + year + '_' + image_set
 28     if use_diff:
 29       name += '_diff'
 30     imdb.__init__(self, name)
 31     self._year = year
 32     self._image_set = image_set
 33     self._devkit_path = self._get_default_path()
 34     self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
 35     self._classes = ('__background__',  # always index 0
 36                      'title', 'xlabel',  'ylabel')
 37                  ####    'text', 'ylabel')
 38                  #    'aeroplane', 'bicycle', 'bird', 'boat',
 39                  #    'bottle', 'bus', 'car', 'cat', 'chair',
 40                  #    'cow', 'diningtable', 'dog', 'horse',
 41                  #    'motorbike', 'person', 'pottedplant',
 42                  #    'sheep', 'sofa', 'train', 'tvmonitor')
 43     self._class_to_ind = dict(list(zip(self.classes, list(range(self.num_classes)))))
 44     self._image_ext = '.jpg'
 45     self._image_index = self._load_image_set_index()
 46     # Default to roidb handler
 47     self._roidb_handler = self.gt_roidb
 48     self._salt = str(uuid.uuid4())
 49     self._comp_id = 'comp4'
 50 
 51     # PASCAL specific config options
 52     self.config = {'cleanup': True,
 53                    'use_salt': True,
 54                    'use_diff': use_diff,
 55                    'matlab_eval': False,
 56                    'rpn_file': None}
 57 
 58     assert os.path.exists(self._devkit_path), \
 59       'VOCdevkit path does not exist: {}'.format(self._devkit_path)
 60     assert os.path.exists(self._data_path), \
 61       'Path does not exist: {}'.format(self._data_path)
 62 
 63   def image_path_at(self, i):
 64     """
 65     Return the absolute path to image i in the image sequence.
 66     """
 67     return self.image_path_from_index(self._image_index[i])
 68 
 69   def image_path_from_index(self, index):
 70     """
 71     Construct an image path from the image's "index" identifier.
 72     """
 73     image_path = os.path.join(self._data_path, 'JPEGImages',
 74                               index + self._image_ext)
 75     assert os.path.exists(image_path), \
 76       'Path does not exist: {}'.format(image_path)
 77     return image_path
 78 
 79   def _load_image_set_index(self):
 80     """
 81     Load the indexes listed in this dataset's image set file.
 82     """
 83     # Example path to image set file:
 84     # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
 85     image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
 86                                   self._image_set + '.txt')
 87     assert os.path.exists(image_set_file), \
 88       'Path does not exist: {}'.format(image_set_file)
 89     with open(image_set_file) as f:
 90       image_index = [x.strip() for x in f.readlines()]
 91     return image_index
 92 
 93   def _get_default_path(self):
 94     """
 95     Return the default path where PASCAL VOC is expected to be installed.
 96     """
 97     return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year)
 98 
 99   def gt_roidb(self):
100     """
101     Return the database of ground-truth regions of interest.
102 
103     This function loads/saves from/to a cache file to speed up future calls.
104     """
105     cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
106     if os.path.exists(cache_file):
107       with open(cache_file, 'rb') as fid:
108         try:
109           roidb = pickle.load(fid)
110         except:
111           roidb = pickle.load(fid, encoding='bytes')
112       print('{} gt roidb loaded from {}'.format(self.name, cache_file))
113       return roidb
114 
115     gt_roidb = [self._load_pascal_annotation(index)
116                 for index in self.image_index]
117     with open(cache_file, 'wb') as fid:
118       pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
119     print('wrote gt roidb to {}'.format(cache_file))
120 
121     return gt_roidb
122 
123   def rpn_roidb(self):
124     if int(self._year) == 2007 or self._image_set != 'test':
125       gt_roidb = self.gt_roidb()
126       rpn_roidb = self._load_rpn_roidb(gt_roidb)
127       roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
128     else:
129       roidb = self._load_rpn_roidb(None)
130 
131     return roidb
132 
133   def _load_rpn_roidb(self, gt_roidb):
134     filename = self.config['rpn_file']
135     print('loading {}'.format(filename))
136     assert os.path.exists(filename), \
137       'rpn data not found at: {}'.format(filename)
138     with open(filename, 'rb') as f:
139       box_list = pickle.load(f)
140     return self.create_roidb_from_box_list(box_list, gt_roidb)
141 
142   def _load_pascal_annotation(self, index):
143     """
144     Load image and bounding boxes info from XML file in the PASCAL VOC
145     format.
146     """
147     filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
148     tree = ET.parse(filename)
149     objs = tree.findall('object')
150     if not self.config['use_diff']:
151       # Exclude the samples labeled as difficult
152       non_diff_objs = [
153         obj for obj in objs if int(obj.find('difficult').text) == 0]
154       # if len(non_diff_objs) != len(objs):
155       #     print 'Removed {} difficult objects'.format(
156       #         len(objs) - len(non_diff_objs))
157       objs = non_diff_objs
158     num_objs = len(objs)
159 
160     boxes = np.zeros((num_objs, 4), dtype=np.uint16)
161     gt_classes = np.zeros((num_objs), dtype=np.int32)
162     overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
163     # "Seg" area for pascal is just the box area
164     seg_areas = np.zeros((num_objs), dtype=np.float32)
165 
166     # Load object bounding boxes into a data frame.
167     for ix, obj in enumerate(objs):
168       bbox = obj.find('bndbox')
169       # Make pixel indexes 0-based
170       x1 = float(bbox.find('xmin').text) - 1
171       y1 = float(bbox.find('ymin').text) - 1
172       x2 = float(bbox.find('xmax').text) - 1
173       y2 = float(bbox.find('ymax').text) - 1
174       cls = self._class_to_ind[obj.find('name').text.lower().strip()]
175       boxes[ix, :] = [x1, y1, x2, y2]
176       gt_classes[ix] = cls
177       overlaps[ix, cls] = 1.0
178       seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)
179 
180     overlaps = scipy.sparse.csr_matrix(overlaps)
181 
182     return {'boxes': boxes,
183             'gt_classes': gt_classes,
184             'gt_overlaps': overlaps,
185             'flipped': False,
186             'seg_areas': seg_areas}
187 
188   def _get_comp_id(self):
189     comp_id = (self._comp_id + '_' + self._salt if self.config['use_salt']
190                else self._comp_id)
191     return comp_id
192 
193   def _get_voc_results_file_template(self):
194     # VOCdevkit/results/VOC2007/Main/_det_test_aeroplane.txt
195     filename = self._get_comp_id() + '_det_' + self._image_set + '_{:s}.txt'
196     path = os.path.join(
197       self._devkit_path,
198       'results',
199       'VOC' + self._year,
200       'Main',
201       filename)
202     return path
203 
204   def _write_voc_results_file(self, all_boxes):
205     for cls_ind, cls in enumerate(self.classes):
206       if cls == '__background__':
207         continue
208       print('Writing {} VOC results file'.format(cls))
209       filename = self._get_voc_results_file_template().format(cls)
210       with open(filename, 'wt') as f:
211         for im_ind, index in enumerate(self.image_index):
212           dets = all_boxes[cls_ind][im_ind]
213           if dets == []:
214             continue
215           # the VOCdevkit expects 1-based indices
216           for k in range(dets.shape[0]):
217             f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
218                     format(index, dets[k, -1],
219                            dets[k, 0] + 1, dets[k, 1] + 1,
220                            dets[k, 2] + 1, dets[k, 3] + 1))
221 
222   def _do_python_eval(self, output_dir='output'):
223     annopath = os.path.join(
224       self._devkit_path,
225       'VOC' + self._year,
226       'Annotations',
227       '{:s}.xml')
228     imagesetfile = os.path.join(
229       self._devkit_path,
230       'VOC' + self._year,
231       'ImageSets',
232       'Main',
233       self._image_set + '.txt')
234     cachedir = os.path.join(self._devkit_path, 'annotations_cache')
235     aps = []
236     # The PASCAL VOC metric changed in 2010
237     use_07_metric = True if int(self._year) < 2010 else False
238     print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
239     if not os.path.isdir(output_dir):
240       os.mkdir(output_dir)
241     for i, cls in enumerate(self._classes):
242       if cls == '__background__':
243         continue
244       filename = self._get_voc_results_file_template().format(cls)
245       rec, prec, ap = voc_eval(
246         filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
247         use_07_metric=use_07_metric, use_diff=self.config['use_diff'])
248       aps += [ap]
249       print(('AP for {} = {:.4f}'.format(cls, ap)))
250       with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
251         pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
252     print(('Mean AP = {:.4f}'.format(np.mean(aps))))
253     print('~~~~~~~~')
254     print('Results:')
255     for ap in aps:
256       print(('{:.3f}'.format(ap)))
257     print(('{:.3f}'.format(np.mean(aps))))
258     print('~~~~~~~~')
259     print('')
260     print('--------------------------------------------------------------')
261     print('Results computed with the **unofficial** Python eval code.')
262     print('Results should be very close to the official MATLAB eval code.')
263     print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
264     print('-- Thanks, The Management')
265     print('--------------------------------------------------------------')
266 
267   def _do_matlab_eval(self, output_dir='output'):
268     print('-----------------------------------------------------')
269     print('Computing results with the official MATLAB eval code.')
270     print('-----------------------------------------------------')
271     path = os.path.join(cfg.ROOT_DIR, 'lib', 'datasets',
272                         'VOCdevkit-matlab-wrapper')
273     cmd = 'cd {} && '.format(path)
274     cmd += '{:s} -nodisplay -nodesktop '.format(cfg.MATLAB)
275     cmd += '-r "dbstop if error; '
276     cmd += 'voc_eval(\'{:s}\',\'{:s}\',\'{:s}\',\'{:s}\'); quit;"' \
277       .format(self._devkit_path, self._get_comp_id(),
278               self._image_set, output_dir)
279     print(('Running:\n{}'.format(cmd)))
280     status = subprocess.call(cmd, shell=True)
281 
282   def evaluate_detections(self, all_boxes, output_dir):
283     self._write_voc_results_file(all_boxes)
284     self._do_python_eval(output_dir)
285     if self.config['matlab_eval']:
286       self._do_matlab_eval(output_dir)
287     if self.config['cleanup']:
288       for cls in self._classes:
289         if cls == '__background__':
290           continue
291         filename = self._get_voc_results_file_template().format(cls)
292         os.remove(filename)
293 
294   def competition_mode(self, on):
295     if on:
296       self.config['use_salt'] = False
297       self.config['cleanup'] = False
298     else:
299       self.config['use_salt'] = True
300       self.config['cleanup'] = True
301 
302 
303 if __name__ == '__main__':
304   from datasets.pascal_voc import pascal_voc
305 
306   d = pascal_voc('trainval', '2007')
307   res = d.roidb
308   from IPython import embed;
309 
310   embed()

我们先看涉及 mAP 的方法,其他方法暂且不谈。

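这里 evaluate_detections 方法先调用 _write_voc_results_file,把检测结果按类别写入 txt 文件,再调用 _do_python_eval 方法。后者对每个类别调用 voc_eval 函数,计算该类别的 AP(245-247 行),最后在 252 行对所有类别的 AP 取平均,得到 mAP。下面是 voc_eval 所在的 FRCN/lib/datasets/voc_eval.py: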

  1 # --------------------------------------------------------
  2 # Fast/er R-CNN
  3 # Licensed under The MIT License [see LICENSE for details]
  4 # Written by Bharath Hariharan
  5 # --------------------------------------------------------
  6 from __future__ import absolute_import
  7 from __future__ import division
  8 from __future__ import print_function
  9 
 10 import xml.etree.ElementTree as ET
 11 import os
 12 import pickle
 13 import numpy as np
 14 
 15 def parse_rec(filename):
 16   """ Parse a PASCAL VOC xml file """
 17   tree = ET.parse(filename)
 18   objects = []
 19   for obj in tree.findall('object'):
 20     obj_struct = {}
 21     obj_struct['name'] = obj.find('name').text
 22     obj_struct['pose'] = obj.find('pose').text
 23     obj_struct['truncated'] = int(obj.find('truncated').text)
 24     obj_struct['difficult'] = int(obj.find('difficult').text)
 25     bbox = obj.find('bndbox')
 26     obj_struct['bbox'] = [int(float(bbox.find('xmin').text)),
 27                           int(float(bbox.find('ymin').text)),
 28                           int(float(bbox.find('xmax').text)),
 29                           int(float(bbox.find('ymax').text))]
 30     objects.append(obj_struct)
 31 
 32   return objects
 33 
 34 
 35 def voc_ap(rec, prec, use_07_metric=False):
 36   """ ap = voc_ap(rec, prec, [use_07_metric])
 37   Compute VOC AP given precision and recall.
 38   If use_07_metric is true, uses the
 39   VOC 07 11 point method (default:False).
 40   """
 41   if use_07_metric:
 42     # 11 point metric
 43     ap = 0.
 44     for t in np.arange(0., 1.1, 0.1):
 45       if np.sum(rec >= t) == 0:
 46         p = 0
 47       else:
 48         p = np.max(prec[rec >= t])
 49       ap = ap + p / 11.
 50   else:
 51     # correct AP calculation
 52     # first append sentinel values at the end
 53     mrec = np.concatenate(([0.], rec, [1.]))
 54     mpre = np.concatenate(([0.], prec, [0.]))
 55 
 56     # compute the precision envelope
 57     for i in range(mpre.size - 1, 0, -1):
 58       mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
 59 
 60     # to calculate area under PR curve, look for points
 61     # where X axis (recall) changes value
 62     i = np.where(mrec[1:] != mrec[:-1])[0]
 63 
 64     # and sum (\Delta recall) * prec
 65     ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
 66   return ap
 67 
 68 
 69 def voc_eval(detpath,
 70              annopath,
 71              imagesetfile,
 72              classname,
 73              cachedir,
 74              ovthresh=0.5,
 75              use_07_metric=False,
 76              use_diff=False):
 77   """rec, prec, ap = voc_eval(detpath,
 78                               annopath,
 79                               imagesetfile,
 80                               classname,
 81                               [ovthresh],
 82                               [use_07_metric])
 83 
 84   Top level function that does the PASCAL VOC evaluation.
 85 
 86   detpath: Path to detections
 87       detpath.format(classname) should produce the detection results file.
 88   annopath: Path to annotations
 89       annopath.format(imagename) should be the xml annotations file.
 90   imagesetfile: Text file containing the list of images, one image per line.
 91   classname: Category name (duh)
 92   cachedir: Directory for caching the annotations
 93   [ovthresh]: Overlap threshold (default = 0.5)
 94   [use_07_metric]: Whether to use VOC07's 11 point AP computation
 95       (default False)
 96   """
 97   # assumes detections are in detpath.format(classname)
 98   # assumes annotations are in annopath.format(imagename)
 99   # assumes imagesetfile is a text file with each line an image name
100   # cachedir caches the annotations in a pickle file
101 
102   # first load gt
103   if not os.path.isdir(cachedir):
104     os.mkdir(cachedir)
105   cachefile = os.path.join(cachedir, '%s_annots.pkl' % imagesetfile)
106   # read list of images
107   with open(imagesetfile, 'r') as f:
108     lines = f.readlines()
109   imagenames = [x.strip() for x in lines]    #test.txt中的所有标号
110 
111   # load annotations
112   if not os.path.isfile(cachefile):    
113     recs = {}
114     for i, imagename in enumerate(imagenames):
115       recs[imagename] = parse_rec(annopath.format(imagename))
116       if i % 100 == 0:
117         print('Reading annotation for {:d}/{:d}'.format(
118           i + 1, len(imagenames)))
119     # save
120     print('Saving cached annotations to {:s}'.format(cachefile))
121     with open(cachefile, 'wb') as f:
122       pickle.dump(recs, f)
123   else:
124     # load
125     with open(cachefile, 'rb') as f:
126       try:
127         recs = pickle.load(f)
128       except:
129         recs = pickle.load(f, encoding='bytes')
130 
131   # extract gt objects for this class
132   class_recs = {}
133   npos = 0
134   for imagename in imagenames:
135     R = [obj for obj in recs[imagename] if obj['name'] == classname]
136     bbox = np.array([x['bbox'] for x in R])
137     if use_diff:
138       difficult = np.array([False for x in R]).astype(np.bool)      
139     else:
140       difficult = np.array([x['difficult'] for x in R]).astype(np.bool)
141     det = [False] * len(R)
142     npos = npos + sum(~difficult)
143     class_recs[imagename] = {'bbox': bbox,
144                              'difficult': difficult,
145                              'det': det}
146 
147   # read dets
148   detfile = detpath.format(classname)
149   with open(detfile, 'r') as f:
150     lines = f.readlines()
151 
152   splitlines = [x.strip().split(' ') for x in lines]
153   image_ids = [x[0] for x in splitlines]
154   confidence = np.array([float(x[1]) for x in splitlines])
155   BB = np.array([[float(z) for z in x[2:]] for x in splitlines])
156 
157   nd = len(image_ids)
158   tp = np.zeros(nd)
159   fp = np.zeros(nd)
160 
161   if BB.shape[0] > 0:
162     # sort by confidence
163     sorted_ind = np.argsort(-confidence)
164     sorted_scores = np.sort(-confidence)
165     BB = BB[sorted_ind, :]
166     image_ids = [image_ids[x] for x in sorted_ind]
167 
168     # go down dets and mark TPs and FPs
169     for d in range(nd):
170       R = class_recs[image_ids[d]]
171       bb = BB[d, :].astype(float)
172       ovmax = -np.inf
173       BBGT = R['bbox'].astype(float)
174 
175       if BBGT.size > 0:
176         # compute overlaps
177         # intersection
178         ixmin = np.maximum(BBGT[:, 0], bb[0])
179         iymin = np.maximum(BBGT[:, 1], bb[1])
180         ixmax = np.minimum(BBGT[:, 2], bb[2])
181         iymax = np.minimum(BBGT[:, 3], bb[3])
182         iw = np.maximum(ixmax - ixmin + 1., 0.)
183         ih = np.maximum(iymax - iymin + 1., 0.)
184         inters = iw * ih
185 
186         # union
187         uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
188                (BBGT[:, 2] - BBGT[:, 0] + 1.) *
189                (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)
190 
191         overlaps = inters / uni
192         ovmax = np.max(overlaps)
193         jmax = np.argmax(overlaps)
194 
195       if ovmax > ovthresh:
196         if not R['difficult'][jmax]:
197           if not R['det'][jmax]:
198             tp[d] = 1.
199             R['det'][jmax] = 1
200           else:
201             fp[d] = 1.
202       else:
203         fp[d] = 1.
204 
205   # compute precision recall
206   fp = np.cumsum(fp)
207   tp = np.cumsum(tp)
208   rec = tp / float(npos)
209   # avoid divide by zero in case the first detection matches a difficult
210   # ground truth
211   prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
212   ap = voc_ap(rec, prec, use_07_metric)
213 
214   return rec, prec, ap

 

voc_eval(filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5, use_07_metric=use_07_metric, use_diff=self.config['use_diff'])

def voc_eval(detpath, annopath, imagesetfile, classname, cachedir, ovthresh=0.5, use_07_metric=False, use_diff=False):

filename(即 detpath):检测结果文件的路径模板。detpath.format(classname) 得到该类别的检测结果 txt 文件,由 _write_voc_results_file 写出,每行格式为 image_id score xmin ymin xmax ymax。

annopath:Annotations 的路径模板,annopath.format(imagename) 得到对应图片的 xml 标注文件。

imagesetfile: 图片集的txt文档

classname: 当前的class

cachedir: 存储Annotations的pkl所在目录(可能不存在)

ovthresh=0.5: IoU的threshold,默认为0.5

use_07_metric=False:是否使用 PASCAL VOC 2007 的 11 点插值 AP 计算规则(2010 年起改为计算插值 PR 曲线下的面积)。

use_diff=False:是否把标注为 difficult 的 GT 目标当作普通目标参与评估。为 False 时,difficult 目标不计入 npos,与其匹配的检测框既不算 TP 也不算 FP。

 

经过一番数据处理,得到了:

BB:当前 class 的所有检测框(predicted bbox),每行一个。

image_ids:与 BB 逐行对应,记录每个检测框所属图片的名字(序号)。

class_recs:以图片名为键的字典。每个值记录该图片中当前 class 的 GT bbox、是否为 difficult,以及是否已被某个检测框匹配(det)。

 

 1   if BB.shape[0] > 0:
 2     # sort by confidence
 3     #'''
 4     sorted_ind = np.argsort(-confidence)
 5     sorted_scores = np.sort(-confidence)
 6     BB = BB[sorted_ind, :]    # 现在的BB是按照conf降序排列的所有predicted bbox
 7     image_ids = [image_ids[x] for x in sorted_ind]    # image_id 是BB每组bbox所属于的image的序号
 8     
 9     #'''
10      
11     # go down dets and mark TPs and FPs
12     for d in range(nd):              #对所有proposal bbox 遍历
13       R = class_recs[image_ids[d]]   # 找到当前bbox对应的image
14       bb = BB[d, :].astype(float)    # bb 为当前proposal bbox的坐标
15       ovmax = -np.inf                # 设置np极小值
16       BBGT = R['bbox'].astype(float) 
17 
18       if BBGT.size > 0:
19         # compute overlaps
20         # intersection
21         ixmin = np.maximum(BBGT[:, 0], bb[0])
22         iymin = np.maximum(BBGT[:, 1], bb[1])
23         ixmax = np.minimum(BBGT[:, 2], bb[2])
24         iymax = np.minimum(BBGT[:, 3], bb[3])
25         iw = np.maximum(ixmax - ixmin + 1., 0.)
26         ih = np.maximum(iymax - iymin + 1., 0.)
27         inters = iw * ih
28 
29         # union
30         uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
31                (BBGT[:, 2] - BBGT[:, 0] + 1.) *
32                (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)
33 
34         overlaps = inters / uni
35         ovmax = np.max(overlaps)
36         jmax = np.argmax(overlaps)
37         print(overlaps)
38 
39       if ovmax > ovthresh:
40         if not R['difficult'][jmax]:    
41           if not R['det'][jmax]:        #是否已经被检测过
42             tp[d] = 1.
43             R['det'][jmax] = 1
44           else:
45             fp[d] = 1.
46       else:
47         fp[d] = 1.

 

疑惑:

这里的Recall计算(voc_eval.py 208行)使用了:

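rec = tp / float(npos),其中 npos 是所有(非 difficult 的)GT bbox 的数量。它真的等于 tp+fn 吗?似乎只有当 fn(包含但未被检测出 bbox 的 image 数量)== npos-tp(未被检测出的 bbox 数量)时才成立?

解答:目标检测评估中,TP 和 FN 都是按 GT bbox 计数的,而不是按 image 计数。每个非 difficult 的 GT bbox 最终要么被某个检测框匹配(计为 TP,同时 R['det'][jmax] 被置 1,保证同一个 GT 只被计一次),要么始终没有被匹配(即 FN)。因此 fn = npos - tp,npos == tp + fn 恒成立,rec = tp / npos 正是标准的召回率定义。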

 

ref: 1. https://datascience.stackexchange.com/questions/25119/how-to-calculate-map-for-detection-task-for-the-pascal-voc-challenge

2. http://mp.weixin.qq.com/s/FaNC9RppIhPf6T_qAz3Slg

3. https://ils.unc.edu/courses/2013_spring/inls509_001/lectures/10-EvaluationMetrics.pdf

4. https://stats.stackexchange.com/questions/260430/average-precision-in-object-detection/263758#263758

你可能感兴趣的:(Pytorch-Faster-RCNN 中的 MAP 实现 (解析imdb.py 和 pascal_voc.py))