NMS详解及pytorch实现:hard-nms(diou\overlap\merge\batched),soft-nms

文章目录

  • NMS详解及pytorch实现:hard-nms(diou\overlap\merge\batched),soft-nms
    • 1 简介
    • 2 原理
    • 3 实现
      • 3.1 伪代码
      • 3.2 pytorch源码
      • 3.3 知识点
    • 参考资料

NMS详解及pytorch实现:hard-nms(diou\overlap\merge\batched),soft-nms

1 简介

非极大值抑制算法(Non-maximum suppression, NMS)是有anchor系列目标检测的标配,如今大部分的One-StageTwo-Stage算法在推断(Inference)阶段都使用了NMS作为网络的最后一层,例如YOLOv3、SSD、Faster-RCNN等。

当然NMS在目前最新的anchor-free(https://www.cnblogs.com/yumoye/p/11022800.html)目标检测算法中(CornerNet、CenterNet等)并不是必须的,对这种检测算法提升的精度也有限,但是这并不影响我们学习NMS。

NMS的本质是搜索局部极大值,抑制非极大值元素,在目标检测中,我们经常将其用于消除多余的检测框(从左到右消除了重复的检测框,只保留当前最大confidence的检测框):

NMS详解及pytorch实现:hard-nms(diou\overlap\merge\batched),soft-nms_第1张图片

NMS有很多种变体,最为常见的Hard-NMS,我们通常所说的NMS就是指Hard-NMS,还有另外一种NMS叫做Soft-NMS,是Hard-NMS的变体,还有一些其他的一些变体(batched\diou\or\and\merge-nms)。

2 原理

最为常见的,也就是咱们提到的nms及为hard-nms,所以这里将以hard-nms入手,剖析内部操作原理。

  • 选取当前类别box中scores最大的那一个,记为current_box,并保留它(为什么保留它,因为它预测出当前位置有物体的概率最大啊,对于我们来说当前confidence越大说明当前box中包含物体的可能行就越大)
  • 计算current_box与其余的box的IOU
  • 如果其IOU大于我们设定的阈值,那么就舍弃这些boxes(由于可能这两个box表示同一目标,因此这两个box的IOU就比较大,会超过我们设定的阈值,所以就保留分数高的那一个)
  • 从最后剩余的boxes中,再找出最大scores的那一个(之前那个大的已经保存到输出的数组中,这个是从剩下的里面再挑一个最大的),如此循环往复

3 实现

3.1 伪代码

NMS详解及pytorch实现:hard-nms(diou\overlap\merge\batched),soft-nms_第2张图片

各种nms特点一句话总结:

Hard-nms–直接删除相邻的同类别目标,密集目标的输出不友好。

Soft-nms–改变其相邻同类别目标置信度(有关iou的函数),后期通过置信度阈值进行过滤,适用于目标密集的场景。

or-nms–hard-nms的非官方实现形式,只支持cpu。

vision-nms–hard-nms的官方实现形式(c函数库),可支持gpu(cuda),只支持单类别输入。

vision-batched-nms–hard-nms的官方实现形式(c函数库),可支持gpu(cuda),支持多类别输入。

and-nms–在hard-nms的逻辑基础上,增加是否为单独框的限制,删除没有重叠框的框(减少误检)。

merge-nms–在hard-nms的基础上,增加保留框位置平滑策略(重叠框位置信息求解平均值),使框的位置更加精确。

diou-nms–在hard-nms的基础上,用diou替换iou,里有参照diou的优势。

3.2 pytorch源码

nms实现函数:

def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.5, multi_cls=True, method='vision'):
    """
    Removes detections with lower object confidence score than 'conf_thres'
    Non-Maximum Suppression to further filter detections.
    Returns detections with shape:
        (x1, y1, x2, y2, object_conf, conf, class)
    """
    # NMS methods https://github.com/ultralytics/yolov3/issues/679 'or', 'and', 'merge', 'vision', 'vision_batch'

    # Box constraints
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximium box width and height

    output = [None] * len(prediction)
    for image_i, pred in enumerate(prediction):
        # Apply conf constraint
        pred = pred[pred[:, 4] > conf_thres]

        # Apply width-height constraint
        pred = pred[(pred[:, 2:4] > min_wh).all(1) & (pred[:, 2:4] < max_wh).all(1)]

        # If none remain process next image
        if len(pred) == 0:
            continue

        # Compute conf
        torch.sigmoid_(pred[..., 5:])
        pred[..., 5:] *= pred[..., 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(pred[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_cls or conf_thres < 0.01:
            i, j = (pred[:, 5:] > conf_thres).nonzero().t()
            pred = torch.cat((box[i], pred[i, j + 5].unsqueeze(1), j.float().unsqueeze(1)), 1)
        else:  # best class only
            conf, j = pred[:, 5:].max(1)
            pred = torch.cat((box, conf.unsqueeze(1), j.float().unsqueeze(1)), 1)

        # Apply finite constraint
        pred = pred[torch.isfinite(pred).all(1)]

        # Get detections sorted by decreasing confidence scores
        pred = pred[pred[:, 4].argsort(descending=True)]

        # Batched NMS
        if method == 'vision_batch':
            output[image_i] = pred[torchvision.ops.boxes.batched_nms(pred[:, :4], pred[:, 4], pred[:, 5], nms_thres)]
            continue

        # All other NMS methods
        det_max = []
        cls = pred[:, -1]
        for c in cls.unique():
            dc = pred[cls == c]  # select class c
            n = len(dc)
            if n == 1:
                det_max.append(dc)  # No NMS required if only 1 prediction
                continue
            elif n > 500:
                dc = dc[:500]  # limit to first 500 boxes: https://github.com/ultralytics/yolov3/issues/117

            if method == 'vision':
                det_max.append(dc[torchvision.ops.boxes.nms(dc[:, :4], dc[:, 4], nms_thres)])

            elif method == 'or':  # default
                while dc.shape[0]:
                    det_max.append(dc[:1])  # save highest conf detection
                    if len(dc) == 1:  # Stop if we're at the last detection
                        break
                    iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes
                    dc = dc[1:][iou < nms_thres]  # remove ious > threshold

            elif method == 'and':  # requires overlap, single boxes erased
                while len(dc) > 1:
                    iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes
                    if iou.max() > 0.5:
                        det_max.append(dc[:1])
                    dc = dc[1:][iou < nms_thres]  # remove ious > threshold

            elif method == 'merge':  # weighted mixture box
                while len(dc):
                    if len(dc) == 1:
                        det_max.append(dc)
                        break
                    i = bbox_iou(dc[0], dc) > nms_thres  # iou with other boxes
                    weights = dc[i, 4:5]
                    dc[0, :4] = (weights * dc[i, :4]).sum(0) / weights.sum()
                    det_max.append(dc[:1])
                    dc = dc[i == 0]
			elif method == 'diounms':  # use diou 
                while dc.shape[0]:
                    det_max.append(dc[:1])  # save highest conf detection
                    if len(dc) == 1:  # Stop if we're at the last detection
                        break
                    diou = bbox_iou(dc[0], dc[1:],DIoU=True)  # diou with other boxes
                    dc = dc[1:][diou < nms_thres]  # remove dious > threshold
            
            elif method == 'soft':  # soft-NMS https://arxiv.org/abs/1704.04503
                sigma = 0.5  # soft-nms sigma parameter
                while len(dc):
                    if len(dc) == 1:
                        det_max.append(dc)
                        break
                    det_max.append(dc[:1])
                    iou = bbox_iou(dc[0], dc[1:])  # iou with other boxes
                    dc = dc[1:]
                    dc[:, 4] *= torch.exp(-iou ** 2 / sigma)  # decay confidences
                    dc = dc[dc[:, 4] > conf_thres]  # https://github.com/ultralytics/yolov3/issues/362

        if len(det_max):
            det_max = torch.cat(det_max)  # concatenate
            output[image_i] = det_max[(-det_max[:, 4]).argsort()]  # sort

    return output

iou计算函数:

def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False):
    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
    box2 = box2.t()

    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
    else:  # x, y, w, h = box1
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

    # Intersection area
    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1
    union = (w1 * h1 + 1e-16) + w2 * h2 - inter

    iou = inter / union  # iou
    if GIoU or DIoU or CIoU:
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
        if GIoU:  # Generalized IoU https://arxiv.org/pdf/1902.09630.pdf
            c_area = cw * ch + 1e-16  # convex area
            return iou - (c_area - union) / c_area  # GIoU
        if DIoU or CIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            # convex diagonal squared
            c2 = cw ** 2 + ch ** 2 + 1e-16
            # centerpoint distance squared
            rho2 = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2)) ** 2 / 4 + ((b2_y1 + b2_y2) - (b1_y1 + b1_y2)) ** 2 / 4
            if DIoU:
                return iou - rho2 / c2  # DIoU
            elif CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
                with torch.no_grad():
                    alpha = v / (1 - iou + v)
                return iou - (rho2 / c2 + v * alpha)  # CIoU

    return iou

pytorch外接c函数库:

#多类别nms时,需要分类别灌入。
#调用方式:torchvision.ops.boxes.nms()
def nms(boxes, scores, iou_threshold):
    """
    Performs non-maximum suppression (NMS) on the boxes according
    to their intersection-over-union (IoU).

    NMS iteratively removes lower scoring boxes which have an
    IoU greater than iou_threshold with another (higher scoring)
    box.

    Parameters
    ----------
    boxes : Tensor[N, 4])
        boxes to perform NMS on. They
        are expected to be in (x1, y1, x2, y2) format
    scores : Tensor[N]
        scores for each one of the boxes
    iou_threshold : float
        discards all overlapping
        boxes with IoU < iou_threshold

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices
        of the elements that have been kept
        by NMS, sorted in decreasing order of scores
    """
    _C = _lazy_import()
    return _C.nms(boxes, scores, iou_threshold)
#多类别nms时,不需要分类别的灌入,内部已通过idxs区分实现。
#调用方式:torchvision.ops.boxes.batched_nms()
def batched_nms(boxes, scores, idxs, iou_threshold):
    """
    Performs non-maximum suppression in a batched fashion.

    Each index value correspond to a category, and NMS
    will not be applied between elements of different categories.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes where NMS will be performed. They
        are expected to be in (x1, y1, x2, y2) format
    scores : Tensor[N]
        scores for each one of the boxes
    idxs : Tensor[N]
        indices of the categories for each one of the boxes.
    iou_threshold : float
        discards all overlapping boxes
        with IoU < iou_threshold

    Returns
    -------
    keep : Tensor
        int64 tensor with the indices of
        the elements that have been kept by NMS, sorted
        in decreasing order of scores
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # strategy: in order to perform NMS independently per class.
    # we add an offset to all the boxes. The offset is dependent
    # only on the class idx, and is large enough so that boxes
    # from different classes do not overlap
    max_coordinate = boxes.max()
    offsets = idxs.to(boxes) * (max_coordinate + 1)
    boxes_for_nms = boxes + offsets[:, None]
    keep = nms(boxes_for_nms, scores, iou_threshold)
    return keep

3.3 知识点

1.nms的gpu版本实现:

如有需求请参考:https://blog.csdn.net/qq_21368481/article/details/85722590

2.nms的应用范围:

只应用在前向推理的过程中,在训练中不进行此步。

3.以上几种nms的性能表现:

https://github.com/ultralytics/yolov3/issues/679

Speed mm:ss COCO mAP @0.5…0.95 COCO mAP @0.5
ultralytics 'OR' 8:20 39.7 60.3
ultralytics 'AND' 7:38 39.6 60.1
ultralytics 'SOFT' 12:00 39.1 58.7
ultralytics 'MERGE' 11:25 40.2 60.4
torchvision.ops.boxes.nms() 5:08 39.7 60.3
torchvision.ops.boxes.batched_nms() 6:00 39.7 60.3

参考资料

https://oldpan.me/archives/write-hard-nms-c

https://github.com/ultralytics/yolov3

你可能感兴趣的:(深度学习知识)