在目标检测任务中,一个检测目标通常会产生多个检测框,为了合理缩减输出检测框数量,往往需要在目标检测模型输出前添加非极大抑制
Non-Maximum Suppression, NMS
。非极大抑制算法能够删除冗余的检测框,并优先保留可靠性最高的结果。
假设有以下输入det
:
[[x_center, y_center, width, height, prob, labels...],
[x_center, y_center, width, height, prob, labels...],
...
[x_center, y_center, width, height, prob, labels...]]
和一个阙值theta
。其中,x_center,y_center
为检测框中心点坐标,width,height
为检测框宽度和高度,prob
表示检测框正确检测的概率,labels
则为一个独热向量,用于表示检测框类别。
非极大抑制算法基于上述输入的主要思想是:首先选择一个prob
最高的检测框A
,然后计算A
与剩余所有检测框的交并比Intersection over Union, IoU
以衡量两个检测框之间的相似度,交并比高于阈值theta
的检测框可以被认为是冗余的劣质检测框,直接删除。接下来再从除了A
以外的保留检测框中选择一个prob
最高的检测框,重复上述过程,直到所有检测框都被遍历一遍即可。
算法伪代码如下:
function nms(det, theta):
index = 所有检测框的索引列表
keep = 存储保留的检测框索引,初始化为空列表
while index is not empty:
i = prob值最高的检测框的索引
index.remove(i)
keep.add(i)
for each j in index:
if IoU(det[i], det[j]) > theta:
index.remove(j)
return det[keep]
所谓交并比,就是矩阵A
与矩阵B
交集的面积除以A
与B
并集的面积。图示如下:
两矩阵并集的面积可以通过公式size_A + size_B - inter_size
得到,因此IoU
计算的重点就是计算两矩阵交集的面积。假设计算可得矩阵A,B
左上角和右下角的坐标分别为:
A: (x_A_min, y_A_min, x_A_max, y_A_max), B: (x_B_min, y_B_min, x_B_max, y_B_min)
如果A, B
相交则可以得到矩阵A, B
交集左上角和右下角的坐标分别如下:
x_inter_min = max(x_A_min, x_B_min)
y_inter_min = max(y_A_min, y_B_min)
x_inter_max = min(x_A_max, x_B_max)
y_inter_max = min(y_A_max, y_B_max)
width_inter = x_inter_max - x_inter_min
height_inter = y_inter_max - y_inter_min
需要注意的是,当矩阵A, B
不相交时,则可能出现x_inter_max
小于x_inter_min
或者y_inter_max
小于y_inter_min
的情况。图示如下:
为了避免计算出现负数,则可以使用max
函数将不相交情况下的宽度和高度变为0
,即:
width_inter = max(0, x_inter_max - x_inter_min)
height_inter = max(0, y_inter_max - y_inter_min)
最终,可以实现交并比的计算函数如下:
def IoU(det, i, j):
# 提取 bounding box 信息
x = det[[i, j], 0]
y = det[[i, j], 1]
w = det[[i, j], 2]
h = det[[i, j], 3]
x_min = x - w / 2
x_max = x + w / 2
y_min = y - h / 2
y_max = y + h / 2
size = w * h
inter_x_min = max(x_min)
inter_x_max = min(x_max)
inter_y_min = max(y_min)
inter_y_max = min(y_max)
inter_size = max(0, inter_x_max - inter_x_min) * max(0, inter_y_max - inter_y_min)
return inter_size / (sum(size) - inter_size)
根据提供的伪代码,可以实现非极大抑制算法如下:
def nms(det, theta):
prob = det[:, 4]
index = prob.argsort().tolist() # 剩余索引, 按可能性从小到大排序
keep = [] # 需要保留的索引
while index:
# 将可能性最高的留下
i = index[-1]
index.remove(i)
keep.append(i)
delete = []
for j in index:
if IoU(det, i, j) >= theta:
delete.append(j)
for j in delete:
index.remove(j)
return det[keep]
def IoU(det, i, j):
# 提取 bounding box 信息
x = det[[i, j], 0]
y = det[[i, j], 1]
w = det[[i, j], 2]
h = det[[i, j], 3]
x_min = x - w / 2
x_max = x + w / 2
y_min = y - h / 2
y_max = y + h / 2
size = w * h
inter_x_min = max(x_min)
inter_x_max = min(x_max)
inter_y_min = max(y_min)
inter_y_max = min(y_max)
inter_size = max(0, inter_x_max - inter_x_min) * max(0, inter_y_max - inter_y_min)
return inter_size / (sum(size) - inter_size)
# det: [[x_center, y_center, width, height, classes ... ], ...]
# theta: IoU 阙值
def nms(det, theta):
prob = det[:, 4]
index = prob.argsort().tolist() # 剩余索引, 按可能性从小到大排序
keep = [] # 需要保留的索引
while index:
# 将可能性最高的留下
i = index[-1]
index.remove(i)
keep.append(i)
delete = []
for j in index:
if IoU(det, i, j) >= theta:
delete.append(j)
for j in delete:
index.remove(j)
return det[keep]
if __name__ == "__main__":
import cv2
import torch
img = cv2.imread("1.png")
det = torch.tensor([[80, 280, 30, 40, 0.9],
[82, 278, 32, 45, 0.8],
[77, 281, 30, 38, 0.6],
[260, 270, 30, 60, 0.7],
[254, 273, 34, 62, 0.8]])
for d in det:
img = cv2.rectangle(img, (d[0] - d[2] / 2, d[1] - d[3] / 2), (d[0] + d[2] / 2, d[1] + d[3] / 2), (255, 255, 0), 1)
cv2.imwrite("1_.png", img)
img = cv2.imread("1.png")
for d in nms(det, 0.5):
img = cv2.rectangle(img, (d[0] - d[2] / 2, d[1] - d[3] / 2), (d[0] + d[2] / 2, d[1] + d[3] / 2), (255, 255, 0), 1)
cv2.imwrite("1_nms.png", img)