直接读取 …\py-faster-rcnn-master\data\VOCdevkit2007\VOC2007\ImageSets\Main\test.txt 中的测试集图片编号进行测试。
原 demo.py 文件:对单张图片进行检测,并把该图片每一类的检测结果分别单独显示。
修改后:从 txt 文件中读取待检测图片的名称进行批量检测,每张图上同时显示所有类别的检测结果,然后保存到自定义目录。
运行代码前,需要将全部图片(或仅测试集图片)复制到 …\py-faster-rcnn-master\data\demo\ 目录下。注意:代码中 test.txt 的路径和结果保存目录写的是作者本机的绝对路径(E:\jianchuan\...),运行前需改成自己的路径。
import _init_paths
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from fast_rcnn.nms_wrapper import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import caffe, os, sys, cv2
import argparse
from PIL import Image
# Detection classes. Position 0 is the background class, which demo() skips
# when drawing; the remaining names label the per-class score/box columns.
CLASSES = ('__background__',) + ('ali', 'ddg1000', 'jg3', 'jilong')

# --net choice -> (model sub-directory under cfg.MODELS_DIR,
#                  caffemodel file name under <cfg.DATA_DIR>/faster_rcnn_models).
NETS = dict(
    vgg16=('VGG16', 'VGG16_faster_rcnn_final.caffemodel'),
    vgg1024=('VGG_CNN_M_1024',
             'vgg_cnn_m_1024_faster_rcnn_iter_70000.caffemodel'),
    zf=('ZF', 'ZF_faster_rcnn_final.caffemodel'),
)

# The original per-class helper vis_detections(im, class_name, dets, thresh=0.5)
# was removed; all classes are now drawn on one figure inside demo().
def demo(net, image_name, thresh=0.7, nms_thresh=0.3,
         save_dir=r'E:\jianchuan\py-faster-rcnn-master_4types_2km\save_results\save_demo_test'):
    """Detect every class in one demo image and save a single annotated figure.

    The image is read from ``<cfg.DATA_DIR>/demo/<image_name>`` and run
    through ``im_detect`` (proposals come from whatever ``cfg.TEST`` is set
    to; the script entry point enables the RPN). For each non-background
    class in CLASSES, detections are filtered with NMS and the score
    threshold, and all survivors are drawn on the same figure. The figure is
    saved to ``save_dir`` as ``<image base name>.png``.

    Args:
        net: loaded Caffe network, passed straight to ``im_detect``.
        image_name: file name of the image inside ``<cfg.DATA_DIR>/demo``.
        thresh: minimum score for a detection to be drawn.
        nms_thresh: overlap threshold passed to ``nms``.
        save_dir: output directory; created if it does not exist.

    Raises:
        IOError: if the image file cannot be read.
    """
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)
    if im is None:
        # cv2.imread returns None instead of raising for a missing or
        # unreadable file; fail here with a clear message rather than
        # somewhere inside im_detect.
        raise IOError('Could not read image: {}'.format(im_file))

    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(net, im)
    timer.toc()
    print('Detection took {:.3f}s for {:d} object proposals'.format(
        timer.total_time, boxes.shape[0]))

    fig, ax = plt.subplots(figsize=(12, 12))
    try:
        # OpenCV loads channels as BGR; reorder to RGB for matplotlib.
        # The image is drawn semi-transparent so the boxes stand out.
        ax.imshow(im[:, :, (2, 1, 0)], aspect='equal', alpha=0.5)

        # Start numbering at 1 so cls_ind skips the background column:
        # boxes holds 4 columns per class, scores one column per class.
        for cls_ind, cls in enumerate(CLASSES[1:], 1):
            cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
            cls_scores = scores[:, cls_ind]
            dets = np.hstack((cls_boxes,
                              cls_scores[:, np.newaxis])).astype(np.float32)
            keep = nms(dets, nms_thresh)
            dets = dets[keep, :]

            # Each row is (x1, y1, x2, y2, score); draw those above thresh.
            for x1, y1, x2, y2, score in dets[dets[:, -1] >= thresh]:
                ax.add_patch(
                    plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False,
                                  edgecolor='red', linewidth=1.5)
                )
                ax.text(x1, y1 - 2,
                        '{:s} {:.3f}'.format(cls, score),
                        bbox=dict(facecolor='blue', alpha=0.5),
                        fontsize=14, color='white')

        ax.axis('off')
        fig.tight_layout()

        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        # Swap only the extension (also handles .JPG and names that contain
        # "jpg" elsewhere, which str.replace would corrupt).
        out_path = os.path.join(save_dir,
                                os.path.splitext(image_name)[0] + '.png')
        fig.savefig(out_path)
        print(out_path)
    finally:
        # Close every figure: demo() runs once per image, and open figures
        # otherwise accumulate for the whole batch.
        plt.close(fig)
def parse_args(argv=None):
    """Parse command-line arguments for the batch detection demo.

    Args:
        argv: argument list to parse; ``None`` (default) uses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``gpu_id`` (int, default 0), ``cpu_mode``
        (bool) and ``demo_net`` (one of the NETS keys, default 'vgg1024').
    """
    parser = argparse.ArgumentParser(description='Faster R-CNN demo')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
                        default=0, type=int)
    parser.add_argument('--cpu', dest='cpu_mode',
                        help='Use CPU mode (overrides --gpu)',
                        action='store_true')
    # The bracketed default in the help text must match `default=`; it
    # previously advertised vgg16 while actually defaulting to vgg1024.
    parser.add_argument('--net', dest='demo_net',
                        help='Network to use [vgg1024]',
                        choices=sorted(NETS), default='vgg1024')
    return parser.parse_args(argv)
# Image-set list whose IDs are fed to demo(); adjust to your own checkout.
TEST_LIST_FILE = r'E:\jianchuan\py-faster-rcnn-master_4types_2km\data\VOCdevkit2007\VOC2007\ImageSets\Main\test.txt'


def read_image_names(list_file, ext='.jpg'):
    """Return image file names built from the image IDs listed in list_file.

    Each line is stripped of surrounding whitespace (including Windows
    carriage returns) and suffixed with ``ext``. Blank lines are skipped so
    that a stray empty line does not turn into a bogus ``'.jpg'`` entry.
    """
    with open(list_file) as fi:
        return [line.strip() + ext for line in fi if line.strip()]


if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals
    args = parse_args()

    prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],
                            'faster_rcnn_end2end', 'test.prototxt')
    caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models',
                              NETS[args.demo_net][1])
    if not os.path.isfile(caffemodel):
        raise IOError(('{:s} not found.\nDid you run ./data/scripts/'
                       'fetch_faster_rcnn_models.sh?').format(caffemodel))

    if args.cpu_mode:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
        caffe.set_device(args.gpu_id)
        cfg.GPU_ID = args.gpu_id
    net = caffe.Net(prototxt, caffemodel, caffe.TEST)
    print('\n\nLoaded network {:s}'.format(caffemodel))

    # Read the test-set IDs and run detection on each image in turn.
    im_names = read_image_names(TEST_LIST_FILE)
    print(im_names)
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print('Demo for data/demo/{}'.format(im_name))
        demo(net, im_name)
参考:《Faster R-CNN 批量测试,且所有类别的检测结果都显示在同一张图上(TensorFlow 实现 faster-rcnn)》