非极大抑制算法简介

算法简要分析:

首先将候选的各个框的置信度进行升序排序.
重复一下步骤,直至所有候选框都被遍历完成:
	1. 选出候选框中置信度最高的框,放入结果中待返回,
	2. 剔除所有与当前最高置信度框重叠过高的候选框,
			即IOU大于预先设定的阈值.
最后返回所有满足条件的框,以及其数量.

非极大抑制算法简介_第1张图片

代码展示:

# Original author: Francisco Massa:
# https://github.com/fmassa/object-detection.torch
# Ported to PyTorch by Max deGroot (02/01/2017)
def nms(boxes, scores, overlap=0.5, top_k=200):
    # boxes  torch.Size([11, 4])
    # scores  torch.Size([11])
    # overlap  0.45
    # top_k  200
    keep = scores.new(scores.size(0)).zero_().long()  # torch.Size([11]) 用于记录该保留下的框的索引
    if boxes.numel() == 0:
        return keep
    x1 = boxes[:, 0]  # xmin
    y1 = boxes[:, 1]  # ymin
    x2 = boxes[:, 2]  # xmax
    y2 = boxes[:, 3]  # ymax
    area = torch.mul(x2 - x1, y2 - y1)  # 所有框的面积
    v, idx = scores.sort(0)  # 对置信度进行升序排序
    idx = idx[-top_k:]  # 获得前top_k个最高置信度的索引
    # xx1 = boxes.new() 
    xx1 = boxes.new()  # 其实这里可以不用新创建,只需要后续改成 xx1 = torch.index_select(x1, 0, idx)即可
    yy1 = boxes.new()
    xx2 = boxes.new()
    yy2 = boxes.new()
    w = boxes.new()
    h = boxes.new()

    count = 0 # 用于记录筛选出来的框的数量
    while idx.numel() > 0:
        i = idx[-1]  # 获得剩余所有框中置信度最高的框的索引
        keep[count] = i  # 选出当前剩余中的框置信度最高的框的索引
        count += 1  # 已选出的框的数量增加1
        if idx.size(0) == 1:  # 所有框都已经遍历完毕
            break
        idx = idx[:-1]  # 置信度最高的索引从待选择部分移除
        torch.index_select(x1, 0, idx, out=xx1)
        # xx1 = torch.index_select(x1, 0, idx) # 使用这行代码的话前面就不需要新创建
        torch.index_select(y1, 0, idx, out=yy1)
        torch.index_select(x2, 0, idx, out=xx2)
        torch.index_select(y2, 0, idx, out=yy2)
        xx1 = torch.clamp(xx1, min=x1[i])  # 获得矩形重合部分的xmin
        yy1 = torch.clamp(yy1, min=y1[i])  # 获得矩形重合部分的ymin
        xx2 = torch.clamp(xx2, max=x2[i])  # 获得矩形重合部分的xmax
        yy2 = torch.clamp(yy2, max=y2[i])  # 获得矩形重合部分的ymax
        w.resize_as_(xx2)  # 这两行代码多余
        h.resize_as_(yy2)  # 这两行代码多余
        w = xx2 - xx1  # 获得矩形重合部分的宽度
        h = yy2 - yy1  # 获得矩形重合部分的高度
        w = torch.clamp(w, min=0.0)  # 这两行代码用于处理不重叠的情形
        h = torch.clamp(h, min=0.0)  # 这两行代码用于处理不重叠的情形
        inter = w*h  # 计算和当前置信度最高框的重叠部分面积
        rem_areas = torch.index_select(area, 0, idx)  # 计算这些剩余部分框各自的面积
        union = (rem_areas - inter) + area[i]  # 计算这些剩余部分框各自和当前置信度最高框合并之后的面积
        IoU = inter/union  # 计算剩余所有框各自与当前最高置信度框的交并比
        idx = idx[IoU.le(overlap)]  # 移除所有重叠过高的框,保留低重叠的框
    return keep, count  # 非极大抑制操作之后该保留的框的索引下标,以及该保留的这些框的数量

你可能感兴趣的:(非极大抑制算法简介)