获取Faster RCNN最终候选框坐标值

经过之前的训练,小鱼对Faster RCNN网络实现检测功能已经了解了一点,接下来就是利用网络的检测结果实现多目标的跟踪.这个专题就用来记录一些实现跟踪道路上的小知识点.

今天小鱼分享的是:如何利用训练好的网络得到测试图片的候选框坐标?

在运行~/py-faster-rcnn/tools/demo.py这个文件时对测试集的目标进行了候选框的可视化,也就是demo.py中的def vis_detections函数.这里可以参考demo.py代码解析了解该代码的主要三个功能:候选框结果的可视化;检测目标框的获取;参数的获取

小鱼就是从这里得到最终的候选框坐标,具体方法为:
1.分析可视化什么东西
可视化函数为:

def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)

    plt.axis('off')
    plt.tight_layout()
    plt.draw()

从函数中可以得到是可视化了检测的目标框bbox,并在每个目标框上表明检测精度score.
2.找到可视化的东西怎么得到的
找出哪里生成bbox,score,从可视化函数的这两个参数可以追溯到下一个函数中调用了可视化函数,即demo函数:

def demo(net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    **scores, boxes = im_detect(net, im)**
    timer.toc()
    print ('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0])

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1 # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        vis_detections(im, cls, dets, thresh=CONF_THRESH)

从这个函数中可以看到也有bbox,score,代码加粗部分.这里又追溯到im_detect(net, im)函数,这个函数出现在~/py-faster-rcnn/lib/fast_rcnn/test.py中,我们就转到这个函数,可以看到在im_detect(net, im)函数中有完整的计算bbox,score代码并return scores, pred_boxes.
这里需要注意的是,在im_detect(net, im)函数中,182行加入打印代码:

print pred_boxes.shape
return scores, pred_boxes

即return代码上方直接打印出输出bbox的维度,得到一个300*(4(K+1))的矩阵,k代表类别数.是最后输入给Fast RCNN的300个候选框,不是我们想要的最终可视化的候选框.
3.將这些需要的东西利用起来
找完之后,就开始各取所需.这里小鱼需要最终的候选框,那就是在可视化中的候选框,我就只需要在~/py-faster-rcnn/tools/demo.py中加入print输出指令即可,如在53行左右加入print bbox
最终的结果如下格式:

Demo for data/demo/00109.jpg
Detection took 0.142s for 300 object proposals
[ 1198.58422852  1014.32049561  1291.8581543   1123.05639648]
[ 675.29943848  634.83068848  766.93762207  724.48535156]
[ 1463.50634766   131.7862854   1548.50048828   223.23049927]
[ 1021.40447998   367.55706787  1138.07788086   479.88537598]
[ 1228.62817383   665.61010742  1330.26538086   781.8638916 ]
[ 1069.45117188   457.67938232  1159.40161133   542.62628174]
[ 588.99707031  251.40000916  685.6192627   361.50817871]
[ 1058.63061523   542.11383057  1131.95068359   612.54180908]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

每一行为~/data/demo/00109.jpg图片中一个目标框的坐标,总共多少行就证明该图片上有多少个目标被检测出来.

你可能感兴趣的:(基于检测的多目标跟踪,faster-r-cnn,demo-py,获取候选框)