当进行预测时,将一张包含待检测目标的图片送入网络,会对同一个目标生成很多的预测框,我们只需要保留将该目标完美标出来的检测框即可,这时就会采用NMS算法来剔除我们不需要的框。
非极大值抑制(Non-Maximum Suppression,简称NMS)的思想是搜索局部极大值,抑制非极大值元素。
nms的过程如下:
1 所有检测框按置信度从高到低排序
2 取当前置信度最高的框,然后删除和这个框的iou高于置信度阈值的框
3 重复第2步直到所有框处理完。
简单说,就是每一次都筛选出每一类里面的得分最大的预测框,然后判断该类的其他预测框框的重合程度,如果重合程度过高(即IOU的值大于所设置的阈值)就剔除,相当于保留一定区域内同一种类得分最大的预测框
1、假设我们预测结果经过decode、置信度筛选等操作得到detections(torch.Size([num_anchors, 7]))
num_anchors:anchors的数量
7:(x1, y1, x2, y2 obj_conf, class_conf, class_pred)
x1, y1, x2, y2:左上横,纵,右下横,纵
obj_conf :预测框内部是否包含物体的置信度,
class_conf :预测框属于某一个种类的置信度
class_pred :预测框属于的某一个种类
# detections [num_anchors, 7]
# 7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred
detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
detections = detections[conf_mask]
2、pytorch官方实现的batched_nms
nms_out_index = boxes.batched_nms(
detections[:, :4],
detections[:, 4] * detections[:, 5],
detections[:, 6],
nms_thres,
)
output[i] = detections[nms_out_index]
def batched_nms(
boxes: Tensor,
scores: Tensor,
idxs: Tensor,
iou_threshold: float,
) -> Tensor:
"""
以批处理方式执行NMS。
每个索引值对应于一个类别,将不会应用于不同类别的元素之间。
Args:
boxes(Tensor[N, 4]): 预测框左上角和右下角坐标
``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and``0 <= y1 < y2``.
scores(Tensor[N]): 每一个预测框的得分
idxs(Tensor[N]): 每个预测框框的类别索引。
iou_threshold (float): 丢弃所有IoU>IoU_阈值的重叠框
Returns:
keep (Tensor): 具有NMS保留的元素索引的int64张量,按分数递减顺序排序
"""
if boxes.numel() == 0:
return torch.empty((0,), dtype=torch.int64, device=boxes.device)
"""
# 为每个类独立执行NMS。
# 我们为所有框添加一个偏移量。
# 偏移量仅取决于类idx,并且足够大,以便来自不同类的框不会重叠
"""
else:
max_coordinate = boxes.max()
offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
boxes_for_nms = boxes + offsets[:, None]
keep = nms(boxes_for_nms, scores, iou_threshold)
return keep
3、重新实现
# 获得预测结果中包含的所有种类
unique_labels = detections[:, -1].cpu().unique()
for cls in unique_labels:
# 按照存在物体的置信度排序
_, conf_sort_index = torch.sort(detections_class[:, 4] * detections_class[:, 5], descending=True)
detections_class = detections_class[conf_sort_index]
# 进行非极大抑制
max_detections = []
while detections_class.size(0):
# 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉
max_detections.append(detections_class[0].unsqueeze(0))
if len(detections_class) == 1:
break
ious = bbox_iou(max_detections[-1], detections_class[1:])
detections_class = detections_class[1:][ious < nms_thres]
# 堆叠
max_detections = torch.cat(max_detections).data
# Add max detections to outputs
output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
detections = detections[conf_mask]
nms_out_index = boxes.batched_nms(
detections[:, :4],
detections[:, 4] * detections[:, 5],
detections[:, 6],
nms_thres,
)
output[i] = detections[nms_out_index]
# #------------------------------------------#
# # 获得预测结果中包含的所有种类
# #------------------------------------------#
# unique_labels = detections[:, -1].cpu().unique()
# if prediction.is_cuda:
# unique_labels = unique_labels.cuda()
# detections = detections.cuda()
# for c in unique_labels:
# #------------------------------------------#
# # 获得某一类得分筛选后全部的预测结果
# #------------------------------------------#
# detections_class = detections[detections[:, -1] == c]
#
# #------------------------------------------#
# # 使用官方自带的非极大抑制会速度更快一些!
# #------------------------------------------#
# keep = nms(
# detections_class[:, :4],
# detections_class[:, 4] * detections_class[:, 5],
# nms_thres
# )
# max_detections = detections_class[keep]
#
# # # 按照存在物体的置信度排序
# # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
# # detections_class = detections_class[conf_sort_index]
# # # 进行非极大抑制
# # max_detections = []
# # while detections_class.size(0):
# # # 取出这一类置信度最高的,一步一步往下判断,判断重合程度是否大于nms_thres,如果是则去除掉
# # max_detections.append(detections_class[0].unsqueeze(0))
# # if len(detections_class) == 1:
# # break
# # ious = bbox_iou(max_detections[-1], detections_class[1:])
# # detections_class = detections_class[1:][ious < nms_thres]
# # # 堆叠
# # max_detections = torch.cat(max_detections).data
#
# # Add max detections to outputs
# output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))
#