输入:
检测到的Boxes(同一个物体可能被检测到很多Boxes,每个box均有分类score)
输出:
最优的Box.(不止一个)
过程:
去除冗余的重叠 Boxes,对全部的 Boxes 进行迭代-遍历-消除.
1.将所有框的得分排序,选中最高分及其对应的框;
2. 遍历其余的框,如果和当前最高分框的重叠面积(IOU)大于一定阈值,则将框删除;
3. 从未处理的框中继续选一个得分最高的,重复上述过程.
假设某物体检测到 4 个 Boxes,每个 Box 分别对应一个类别 Score,根据 Score 从小到大排列依次为,(B1, S1), (B2, S2), (B3, S3), (B4, S4). S4 > S3 > S2 > S1.
Step 1. 根据Score 大小,从 Box B4 框开始;
Step 2. 分别计算 B1, B2, B3 与 B4 的重叠程度 IoU,判断是否大于预设定的阈值;如果大于设定阈值,则舍弃该 Box;同时标记保留的 Box. 假设 B3 与 B4 的阈值超过设定阈值,则舍弃 B3,标记 B4 为要保留的 Box;
Step 3. 从剩余的 Boxes 中 B1, B2 中选取 Score 最大的 B2, 然后计算 B2 与 剩余的 B1 的重叠程度 IoU;如果大于设定阈值,同样丢弃该 Box;同时标记保留的 Box.
重复以上过程,直到找到全部的保留 Boxes.
(最后找的Boxes不一定只有一个)
这里为什么是 ? 其实是为了减少误判的可能。
#测试 NMS(非极大值抑制算法)
import numpy as np
import matplotlib.pyplot as plt
#定义盒子
boxes=np.array([[100,100,210,220,0.71],
[250,250,420,420,0.8],
[220,200,320,330,0.92],
[100,100,210,210,0.72],
[230,240,325,330,0.81],
[220,230,315,340,0.9]])
def iou(xmin,ymin,xmax,ymax,areas,lastInd,beforeInd,threshold):
#xmin[lastInd]一个单独数据 m ,xmin[beforeInd]一组数据
#比较m和后者的每个数据,m大,则后一组数据对应位置为 m
#当前bbox和剩下bbox之间的交叉区域
# 选择大于x1,y1和小于x2,y2的区域
xminTmp=np.maximum(xmin[lastInd],xmin[beforeInd])
yminTmp=np.maximum(ymin[lastInd],ymin[beforeInd])
xmaxTmp=np.minimum(xmax[lastInd],xmax[beforeInd])
ymaxTmp=np.minimum(ymax[lastInd],ymax[beforeInd])
width=np.maximum(0.0,xmaxTmp-xminTmp+1)
height=np.maximum(0.0,ymaxTmp-yminTmp+1)
#print(width,height)
#计算存活 box 和 last 指向box的交集
interSection=width*height
union=areas[beforeInd]+areas[lastInd]-interSection
iou_value=interSection/union
print("iou_value",iou_value)
indexOutput=[item[0] for item in zip(beforeInd,iou_value) if item[1]<=threshold]
return indexOutput
def nms(boxes,threshold):
#判断是否是ndarray的数据
assert isinstance(boxes,np.ndarray)
assert boxes.shape[1]==5
xmin=boxes[:,0]
ymin=boxes[:,1]
xmax=boxes[:,2]
ymax=boxes[:,3]
scores=boxes[:,4]
#计算每个盒子的区域
#print(xmax-xmin+1)
areas=(xmax-xmin+1)*(ymax-ymin+1)
#每个盒子的分数升序排列
scoreSort=sorted(list(enumerate(scores)),key=lambda item:item[1])
#save 升序排列的 index [0, 3, 1, 4, 5, 2]
index=[item[0] for item in scoreSort]
keep=[]
while len(index)>0:
lastInd=index[-1]
keep.append(lastInd)
#print(index[:-1])
#计算 iou index[:-1]:打印 [0,3,1,4,5]
index=iou(xmin,ymin,xmax,ymax,areas,lastInd,index[:-1],threshold)
#print(keep)
return keep
def bbox(dets,c='k'):
x1=dets[:,0]
y1=dets[:,1]
x2=dets[:,2]
y2=dets[:,3]
plt.plot([x1,x2],[y1,y1],c)
plt.plot([x1,x1],[y1,y2],c)
plt.plot([x1,x2],[y2,y2],c)
plt.plot([x2,x2],[y1,y2],c)
plt.title("after nms")
if __name__=='__main__':
bbox(boxes,'k')
remain=nms(boxes,threshold=0.6)
#print(boxes[remain])
bbox(boxes[remain],'y')
运行结果(黄色为可能的选框):
相关推荐:
https://chenzomi12.github.io/2016/12/14/YOLO-nms/