NMS 非极大值抑制原理及实现

非极大值抑制(Non-Maximum Suppression)

在物体检测领域中,最后预测的结果一般有很多,其中会有一部分预测框重叠在一起。NMS的操作就是去除重叠的预测框,一般是物体检测任务中检测的最后一步处理操作。

如上图所示,可以看到图片中共有两个狗和一个猫,但是根据模型预测一共有3个狗和2个猫。并且我们人类可以很直观地发现有几个框重复预测同一个物体,那么该如何使用算法让计算机去除掉重复的框呢?

NMS过程解析

假设我们已经有了预测的框,每个预测框对应的类别,每个预测框对应的类别得分。

  1. 第一步,对于每个类别,按照类别得分的从大到小顺序排列:

  2. 第二步,对于每个类别,计算得分最大的框与其余预测框之间的IoU
    关于IoU的说明,详见 IOU(Jaccard系数)概念及实现
    预测框与自身的IoU为1,但是不用考虑,因此我们用 - 划掉
    NMS 非极大值抑制原理及实现_第1张图片

  3. 第三步,根据设定的阈值,剔除掉IOU大于阈值的预测框
    NMS 非极大值抑制原理及实现_第2张图片

  4. 第四步,对下一个没被去除的得分最大的框重复2,3操作,直到所有没被抑制的框遍历完毕。
    NMS 非极大值抑制原理及实现_第3张图片
    经过NMS操作后,我们得到了一个预测更好的结果

Pytorch代码实现

def nms(nms_threshold, pred_bbox, pred_classes, pred_scores):
    """
    对预测结果进行非极大值抑制

    :param nms_threshold: 0~1之间的浮点数,最大重叠程度,IoU高于此阈值的框将会被抑制
    :param pred_bbox: torch.FloatTensor类型,形状为 [num_box, 4],
      4表示边界坐标,即 $(x_{min}, y_{min}, x_{max}, y_{max})$
    :param pred_classes: torch.LongTensor类型,形状为 [num_box],表示预测框的类别
    :param pred_scores: torch.FloatTensor类型,形状为 [num_box],表示预测框的类别得分,其中的值范围在 [0, 1]
    :return: 经过非极大值抑制后的预测框、对应类别和对应得分
    """
    # 首先按照得分进行排序,并将预测框和预测类别保持相对应的关系
    pred_scores, sort_index = pred_scores.sort(dim=0, descending=True)
    pred_bbox = pred_bbox[sort_index]
    pred_classes = pred_classes[sort_index]
    # 全局记录需要被抑制的预测框
    suppress = torch.zeros_like(pred_classes, dtype=torch.uint8, device=pred_classes.device)
    # 对每种类别进行处理
    for c in torch.unique(pred_classes):
        index = pred_classes == c  # 获取属于该类别的索引
        bbox = pred_bbox[index, :]  # 提取属于该类别的预测框
        # 计算该类别预测框的IoU
        overlap = find_jaccard_overlap(bbox, bbox)
        # 记录属于该类别预测框的抑制结果
        mask = torch.zeros(bbox.shape[0], dtype=torch.uint8, device=pred_classes.device)
        for b in range(bbox.shape[0]):
            # 如果该框已经被标记为抑制,则不进行处理
            if mask[b] == 1:
                continue
            # 将IoU大于阈值的标记为抑制
            mask = torch.maximum(mask, (overlap[b, :] > nms_threshold).to(torch.uint8))
            # 自己与自己的IoU为1,但是不应该被自己抑制
            mask[b] = 0
        # 更新全局记录
        suppress[index] = mask

    # 处理完所有类别后,将会得到所有框的抑制结果,需要抑制的框对应位置为1,反之为0
    # 将结果取反,获取不需要被抑制的布尔索引
    suppress = (1 - suppress).to(torch.bool)
    # 根据布尔索引提取最终结果
    pred_bbox = pred_bbox[suppress]
    pred_classes = pred_classes[suppress]
    pred_scores = pred_scores[suppress]
    # 如果没有结果符合,则返回None
    if len(pred_classes) == 0:
        return None
    return pred_bbox, pred_classes, pred_scores

你可能感兴趣的:(#,物体检测,深度学习,计算机视觉,目标检测)