参考:https://blog.csdn.net/10km/article/details/68926498
#增加ax参数
def vis_detections(im, class_name, dets, ax, thresh=0.5):
"""Draw detected bounding boxes."""
inds = np.where(dets[:, -1] >= thresh)[0]
if len(inds) == 0:
return
# 删除这三行
# im = im[:, :, (2, 1, 0)]
# fig, ax = plt.subplots(figsize=(12, 12))
# ax.imshow(im, aspect='equal')
# 删除这三行
# plt.axis('off')
# plt.tight_layout()
# plt.draw()
# Visualize detections for each class
CONF_THRESH = 0.8
NMS_THRESH = 0.3
# 将vis_detections 函数中for 循环之前的3行代码移动到这里
im = im[:, :, (2, 1, 0)]
fig,ax = plt.subplots(figsize=(12, 12))
ax.imshow(im, aspect='equal')
for cls_ind, cls in enumerate(CLASSES[1:]):
cls_ind += 1 # because we skipped background
cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
cls_scores = scores[:, cls_ind]
dets = np.hstack((cls_boxes,
cls_scores[:, np.newaxis])).astype(np.float32)
keep = nms(dets, NMS_THRESH)
dets = dets[keep, :]
#将ax做为参数传入vis_detections
vis_detections(im, cls, dets, ax,thresh=CONF_THRESH)
# 将vis_detections 函数中for 循环之后的3行代码移动到这里
plt.axis('off')
plt.tight_layout()
plt.draw()
测试自己的图片需要修改:
1. CLASSES = ('__background__',
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')
图片类别修改
2. im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name) #demo为测试图片路径
3. net.create_architecture("TEST", 11, tag='default', anchor_scales = [8,16,32] ) #更改11为类别数加1
4. im_names #测试图片路径,保存为文件夹的修改方式
im_names = os.listdir(“ ”) #测试图片所在位置
for im_name in im_names:
print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
print('Demo for data/demo/{}'.format(im_name))
demo(sess, net, im_name)
#保存测试图片所在位置,并设置输出格式
plt.savefig(“ ”)