对模型评估,我们需要得到的文件:
1.各类检测到的目标框txt文件。需要通过下面对程序生成。
txt文件内容如下,第一列是图像名字(不带后缀),第二列是置信度,剩下依次是xmin、ymin、xmax、ymax
2.Annotations文件。制作VOC数据集时候就会有,.\data\VOCdevkit2007\VOC2007\Annotations下
3.验证图像名字列表,这4个文件中,评估时候用到的是test.txt。
4.test_net.py 文件
放Faster-RCNN-TensorFlow-Python3.5-master 根文件夹。
#!/usr/bin/env python
# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen, based on code from Ross Girshick
# --------------------------------------------------------
"""
Demo script showing detections in sample images.
See README.md for installation instructions before running.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import tensorflow as tf
from lib.nets.vgg16 import vgg16
from lib.datasets.factory import get_imdb
from lib.utils.test import test_net
# NETS = {'vgg16': ('vgg16_faster_rcnn_iter_70000.ckpt',), 'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_40000.ckpt',)} #训练输出模型
DATASETS = {'pascal_voc': ('voc_2007_trainval',), 'pascal_voc_0712': ('voc_2007_trainval+voc_2012_trainval',)}
def parse_args():
"""Parse input arguments."""
parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN test')
parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
choices=NETS.keys(), default='vgg16')
parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
choices=DATASETS.keys(), default='pascal_voc')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
# model path
demonet = args.demo_net
dataset = args.dataset
tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default', NETS[demonet][0]) #模型路径
# 获得模型文件名称
filename = (os.path.splitext(tfmodel)[0]).split('\\')[-1]
filename = 'default' + '/' + filename
imdb = get_imdb("voc_2007_test") # 得到
imdb.competition_mode('competition mode')
if not os.path.isfile(tfmodel + '.meta'):
print(tfmodel)
raise IOError(('{:s} not found.\nDid you download the proper networks from '
'our server and place them properly?').format(tfmodel + '.meta'))
# set config
tfconfig = tf.ConfigProto(allow_soft_placement=True)
tfconfig.gpu_options.allow_growth = True
# init session
sess = tf.Session(config=tfconfig)
# load network
if demonet == 'vgg16':
net = vgg16(batch_size=1)
# elif demonet == 'res101':
# net = resnetv1(batch_size=1, num_layers=101)
else:
raise NotImplementedError
net.create_architecture(sess, "TEST", 9, # 记得修改第3个参数为:类别数量+1
tag='default', anchor_scales=[8, 16, 32])
saver = tf.train.Saver()
saver.restore(sess, tfmodel)
print('Loaded network {:s}'.format(tfmodel))
test_net(sess, net, imdb, filename, max_per_image=100)
sess.close()
需要自行修改的地方:
1.改为自己训练输出对ckpt文件名
NETS = {'vgg16': ('vgg16_faster_rcnn_iter_40000.ckpt',)} #训练输出模型
2.修改第3个参数为:类别数量+1
net.create_architecture(sess, "TEST", 9, # 记得修改第3个参数为:类别数量+1
tag='default', anchor_scales=[8, 16, 32])
1.36行左右:修改类别名,注意__background__保留
self._classes = ('__background__', # always index 0
'cls1', 'cls2', 'cls3', 'cls4', 'cls5', 'cls6')
2.函数_do_python_eval中添加画图代码
3.确保有这个文件夹,可修改也可自行创建
1.修改路径
2.修改文件路径和名
最后运行test_net.py,程序会打印出评估结果
默认在.\data\VOCdevkit2007\annotations_cache下
部分路径可能要根据实际情况进行修改,切勿生搬硬套。
参考博客:
https://blog.csdn.net/sihaiyinan/article/details/89417963#commentBox
https://blog.csdn.net/ff_xun/article/details/82354999#commentBox