The evaluation metrics for the TensorFlow version of Faster R-CNN all live in lib/datasets/pascal_voc.py and voc_eval.py.
First, open pascal_voc.py and add the following imports at the top:
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve
from itertools import cycle
import pylab as pl
Then find the following function and change it as shown:
def _do_python_eval(self, output_dir='output'):
    annopath = self._devkit_path + '\\VOC' + self._year + '\\Annotations\\' + '{:s}.xml'
    imagesetfile = os.path.join(
        self._devkit_path,
        'VOC' + self._year,
        'ImageSets',
        'Main',
        self._image_set + '.txt')
    cachedir = os.path.join(self._devkit_path, 'annotations_cache')
    aps = []
    # added
    recs = []
    precs = []
    # end of addition
    # The PASCAL VOC metric changed in 2010
    use_07_metric = True if int(self._year) < 2010 else False
    print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)
    for i, cls in enumerate(self._classes):
        if cls == '__background__':
            continue
        filename = self._get_voc_results_file_template().format(cls)
        rec, prec, ap = voc_eval(
            filename, annopath, imagesetfile, cls, cachedir, ovthresh=0.5,
            use_07_metric=use_07_metric)
        aps += [ap]
        # added
        '''
        recs += [rec[-1]]
        precs += [prec[-1]]
        print('AP for {} = {:.4f}'.format(cls, ap))
        print('recall for {} = {:.4f}'.format(cls, rec[-1]))
        print('precision for {} = {:.4f}'.format(cls, prec[-1]))
        with open(os.path.join(output_dir, cls + '_pr.pkl'), 'w') as f:
            pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
        print('Mean AP = {:.4f}'.format(np.mean(aps)))
        print('~~~~~~~~')
        print('Results:')
        '''
        pl.plot(rec, prec, lw=2,
                label='Precision-recall curve of class {} (area = {:.4f})'
                      ''.format(cls, ap))
        print(('AP for {} = {:.4f}'.format(cls, ap)))
        with open(os.path.join(output_dir, cls + '_pr.pkl'), 'wb') as f:
            pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
    # added
    pl.xlabel('Recall')
    pl.ylabel('Precision')
    plt.grid(True)
    pl.ylim([0.0, 1.2])
    pl.xlim([0.0, 1.0])
    pl.title('Precision-Recall')
    pl.legend(loc="upper right")
    plt.show()
    print(('Mean AP = {:.4f}'.format(np.mean(aps))))
    print('~~~~~~~~')
    print('Results:')
    for ap in aps:
        print(('{:.3f}'.format(ap)))
    print(('{:.3f}'.format(np.mean(aps))))
    print('~~~~~~~~')
    print('')
    print('--------------------------------------------------------------')
    print('Results computed with the **unofficial** Python eval code.')
    print('Results should be very close to the official MATLAB eval code.')
    print('Recompute with `./tools/reval.py --matlab ...` for your paper.')
    print('-- Thanks, The Management')
    print('--------------------------------------------------------------')
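Since the modified _do_python_eval pickles rec, prec, and ap for each class into <class>_pr.pkl under output_dir, the curves can be re-drawn later without re-running detection. A minimal sketch, assuming an 'output' directory and a class named 'car' (both are placeholders for your own setup):

# Minimal sketch: reload a saved <class>_pr.pkl and re-draw its PR curve.
# 'output/car_pr.pkl' is a placeholder path; substitute your output_dir and class name.
import pickle
import matplotlib.pyplot as plt

with open('output/car_pr.pkl', 'rb') as f:   # 'rb' matches the 'wb' used when dumping
    data = pickle.load(f)

rec, prec, ap = data['rec'], data['prec'], data['ap']
plt.plot(rec, prec, lw=2, label='car (AP = {:.4f})'.format(ap))
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.2])
plt.grid(True)
plt.legend(loc='upper right')
plt.show()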
In the same file, comment out the following lines in another function (so that the per-class text files of predicted boxes generated by detection are kept after evaluation):
def evaluate_detections(self, all_boxes, output_dir):
    self._write_voc_results_file(all_boxes)
    self._do_python_eval(output_dir)
    if self.config['matlab_eval']:
        self._do_matlab_eval(output_dir)
    # if self.config['cleanup']:
    #     for cls in self._classes:
    #         if cls == '__background__':
    #             continue
    #         filename = self._get_voc_results_file_template().format(cls)
    #         os.remove(filename)
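With the cleanup block commented out, the per-class result files written by _write_voc_results_file stay on disk after evaluation; in this codebase each line holds an image id, a confidence score, and the box corners. A small sketch for inspecting one of them, with a placeholder file name (the real name and location come from _get_voc_results_file_template() in your setup):

# Minimal sketch: inspect one kept per-class results file.
# The path is a placeholder; the actual name (e.g. comp4_det_test_<class>.txt)
# depends on _get_voc_results_file_template().
results_file = 'results/VOC2007/Main/comp4_det_test_car.txt'

with open(results_file) as f:
    lines = [line.strip().split(' ') for line in f if line.strip()]

# each line: image_id score xmin ymin xmax ymax
high_conf = [l for l in lines if float(l[1]) >= 0.5]
print('{} detections total, {} with score >= 0.5'.format(len(lines), len(high_conf)))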
In voc_eval.py, make the following change:
def parse_rec(filename):  # parse an annotation XML file
    """ Parse a PASCAL VOC xml file """
    tree = ET.parse('' + filename)
    objects = []  # ./data/VOCdevkit2007/VOC2007/Annotations/
    for obj in tree.findall('object'):
        obj_struct = {}
        obj_struct['name'] = obj.find('name').text
        obj_struct['pose'] = obj.find('pose').text
        obj_struct['truncated'] = int(obj.find('truncated').text)
        obj_struct['difficult'] = int(obj.find('difficult').text)
        bbox = obj.find('bndbox')
        obj_struct['bbox'] = [int(bbox.find('xmin').text),
                              int(bbox.find('ymin').text),
                              int(bbox.find('xmax').text),
                              int(bbox.find('ymax').text)]
        objects.append(obj_struct)
    return objects
(Make sure objects is initialized as an empty list before the loop.)
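To sanity-check the change, parse_rec can be called on a single annotation file; it returns one dict per object. A quick sketch, assuming the import path used elsewhere in this repo and an example XML path:

# Quick check of parse_rec on one annotation file (the path is only an example).
from lib.datasets.voc_eval import parse_rec

objs = parse_rec('./data/VOCdevkit2007/VOC2007/Annotations/000001.xml')
for o in objs:
    print(o['name'], o['difficult'], o['bbox'])
# Each entry looks like:
# {'name': ..., 'pose': ..., 'truncated': 0 or 1, 'difficult': 0 or 1,
#  'bbox': [xmin, ymin, xmax, ymax]}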
Create a new file named test-net.py under the faster-rcnn-tensorflow-python3.5-master folder:
# !/usr/bin/env python
# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------
"""
Demo script showing detections in sample images.
See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os

import tensorflow as tf

from lib.nets.vgg16 import vgg16
from lib.datasets.factory import get_imdb
from lib.utils.test import test_net

# NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',), 'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_5000.ckpt',)}  # checkpoint produced by training
DATASETS = {'pascal_voc': ('voc_2007_trainval',), 'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}


def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN test')
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default', NETS[demonet][0])  # path to the trained checkpoint

    # get the model file name
    filename = (os.path.splitext(tfmodel)[0]).split('\\')[-1]
    filename = 'default' + '/' + filename

    imdb = get_imdb("voc_2007_test")  # load the test imdb
    imdb.competition_mode('competition mode')

    if not os.path.isfile(tfmodel + '.meta'):
        print(tfmodel)
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # init session
    sess = tf.Session(config=tfconfig)

    # load network
    if demonet == 'vgg16':
        net = vgg16(batch_size=1)
    # elif demonet == 'res101':
    #     net = resnetv1(batch_size=1, num_layers=101)
    else:
        raise NotImplementedError
    net.create_architecture(sess, "TEST", 8,  # remember to set the 3rd argument to: number of classes + 1
                            tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)

    print('Loaded network {:s}'.format(tfmodel))
    print(filename)
    test_net(sess, net, imdb, filename, max_per_image=100)

    sess.close()
(Remember to adjust the paths and the number of classes for your own setup.)
Run the script (e.g. python test-net.py) to generate the per-class detection text files and the PR curve.