Faster R-CNN画PR曲线

参考:https://github.com/rbgirshick/py-faster-rcnn/issues/670

在 `lib/datasets/pascal_voc.py` 里添加几行代码即可:

1. 在文件头部添加导入(其中 `precision_recall_curve` 和 `cycle` 在下面的代码中并未用到,可以省略,这样也就不必安装 scikit-learn):

import matplotlib.pyplot as plt
import pylab as pl
from sklearn.metrics import precision_recall_curve
from itertools import cycle

2. 修改 `_do_python_eval` 函数(新增的是循环中的 `pl.plot(...)` 调用,以及循环结束后设置坐标轴、标题、图例并调用 `plt.show()` 的几行):

 def _do_python_eval(self, output_dir='output'):
    """Evaluate detections with the Python VOC code and plot per-class PR curves.

    For every class except '__background__' this formats the class's
    detection results filename, runs ``voc_eval`` with ``ovthresh=0.5``,
    plots the returned recall/precision arrays, and pickles
    ``{'rec', 'prec', 'ap'}`` to ``<output_dir>/<cls>_pr.pkl``. All curves
    are drawn on one freshly created figure, which is saved to
    ``<output_dir>/pr_curve.png``, shown, and then closed. Finally the
    per-class APs and their mean are printed.

    Args:
      output_dir: directory for the pickled PR data and the PR figure.
        Created with ``os.mkdir`` if missing (so its parent must exist).
    """
    annopath = os.path.join(
      self._devkit_path,
      'VOC' + self._year,
      'Annotations',
      '{:s}.xml')
    imagesetfile = os.path.join(
      self._devkit_path,
      'VOC' + self._year,
      'ImageSets',
      'Main',
      self._image_set + '.txt')
    cachedir = os.path.join(self._devkit_path, 'annotations_cache')
    aps = []
    # The PASCAL VOC metric changed in 2010
    use_07_metric = int(self._year) < 2010
    print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
    if not os.path.isdir(output_dir):
      os.mkdir(output_dir)

    # Draw on a dedicated figure: plotting on pyplot's implicit "current"
    # figure would pile these curves onto whatever was drawn before (e.g. a
    # previous call when the backend is non-interactive and show() is a no-op).
    fig = plt.figure()
    for cls in self._classes:
      if cls == '__background__':
        continue
      filename = self._get_voc_results_file_template().format(cls)
      rec, prec, ap = voc_eval(
        filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
        use_07_metric=use_07_metric)
      aps.append(ap)
      plt.plot(rec, prec, lw=2,
               label='Precision-recall curve of class {} (area = {:.4f})'
                     .format(cls, ap))
      print('AP for {} = {:.4f}'.format(cls, ap))
      # Keep the raw curve so several runs can be compared/re-plotted later.
      with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
        pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.grid(True)
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.title('Precision-Recall')
    plt.legend(loc="upper right")
    # Save before show(): on a headless machine show() displays nothing, and
    # after an interactive window is closed the figure is gone.
    fig.savefig(os.path.join(output_dir, 'pr_curve.png'))
    plt.show()
    # Release the figure so repeated evaluations do not accumulate figures.
    plt.close(fig)

    print('Mean AP = {:.4f}'.format(np.mean(aps)))
    print('~~~~~~~~')
    print('Results:')
    for ap in aps:
      print('{:.3f}'.format(ap))
    print('{:.3f}'.format(np.mean(aps)))
    print('~~~~~~~~')
    print('')
    print('--------------------------------------------------------------')
    print('Results computed with the **unofficial** Python eval code.')
    print('Results should be very close to the official MATLAB eval code.')
    print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
    print('-- Thanks, The Management')
    print('--------------------------------------------------------------')

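然后运行 `tools/test_net.py`,就可以得到如下图的 PR 曲线。注意 `plt.show()` 需要图形界面,在没有显示器的服务器上不会弹出窗口,此时可以在它之前调用 `plt.savefig(...)` 把图保存为文件。如果想比较多条曲线,可以先把 rec、prec 数据存起来再画图:上面的代码已经用 `pickle.dump` 把每个类别的 rec、prec 和 ap 保存到 `output_dir` 下的 `<类别名>_pr.pkl` 中,之后读取这些文件即可重新绘图。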

(图:Faster R-CNN 各类别的 PR 曲线)

你可能感兴趣的:(目标检测)