Faster rcnn 中的nms解读

最近在看Faster rcnn相关的东西,也真好碰巧面试遇到手写nms代码这个环节,故本文将其记录下来以备日后使用。


paper里面的原版解读:

  1. import numpy as np  
  2.   
  3. def py_cpu_nms(dets, thresh):  
  4.     """Pure Python NMS baseline."""  
  5.     # 所有图片的坐标信息,字典形式储存??  
  6.     x1 = dets[:, 0]  
  7.     y1 = dets[:, 1]  
  8.     x2 = dets[:, 2]  
  9.     y2 = dets[:, 3]  
  10.     scores = dets[:, 4]  
  11.   
  12.     areas = (x2 - x1 + 1) * (y2 - y1 + 1# 计算出所有图片的面积  
  13.     order = scores.argsort()[::-1# 图片评分按升序排序  
  14.   
  15.     keep = [] # 用来存放最后保留的图片的相应评分  
  16.     while order.size > 0:   
  17.         i = order[0# i 是还未处理的图片中的最大评分  
  18.         keep.append(i) # 保留改图片的值  
  19.         # 矩阵操作,下面计算的是图片i分别与其余图片相交的矩形的坐标  
  20.         xx1 = np.maximum(x1[i], x1[order[1:]])   
  21.         yy1 = np.maximum(y1[i], y1[order[1:]])  
  22.         xx2 = np.minimum(x2[i], x2[order[1:]])  
  23.         yy2 = np.minimum(y2[i], y2[order[1:]])  
  24.   
  25.         # 计算出各个相交矩形的面积  
  26.         w = np.maximum(0.0, xx2 - xx1 + 1)  
  27.         h = np.maximum(0.0, yy2 - yy1 + 1)  
  28.         inter = w * h  
  29.         # 计算重叠比例  
  30.         ovr = inter / (areas[i] + areas[order[1:]] - inter)  
  31.   
  32.         #只保留比例小于阙值的图片,然后继续处理  
  33.         inds = np.where(ovr <= thresh)[0]  
  34.         order = order[inds + 1]  
  35.   
  36.     return keep 


自己仿照的实现:

import numpy as np

def nms(boxes,thresh):
    if len(boxes) == 0:
        return []

    if boxes.dtype.kind == 'i':
        boxes = boxes.astype('float')

    pick = []
    x1 = boxes[:,0]
    y1 = boxes[:,1]
    x2 = boxes[:,2]
    y2 = boxes[:,3]
    score = boxes[:,4]

    aera = (x2-x1+1)*(y2-y1+1)
    #按照评分的升序排列
    index = np.argsort(score)[::-1]

    while len(index) > 0:
        i = index[0]
        pick.append(i)
        xx1 = np.maximum(x1[i],x1[index[1:]])
        yy1 = np.maximum(y1[i],y1[index[1:]])
        xx2 = np.minimum(x2[i],x2[index[1:]])
        yy2 = np.minimum(y2[i],y2[index[1:]])

        w = np.maximum(0,xx2-xx1+1)
        h = np.maximum(0,yy2-yy1+1)

        if w>0 and h>0:
            overlap = (w*h)/(aera[i] + aera[1:] - (w*h))
        inds = np.where(overlap <= thresh)[0]
        index = index[inds + 1]
不得不说,np.where这个操作还是很经典的,我自己最开始是采用的np.delete的方式,但是后来又看了原版实现。

你可能感兴趣的:(Faster rcnn 中的nms解读)