NMS(非极大值抑制)算法详解与示例

一、NMS是什么?

NMS(non maximum suppression)即非极大值抑制,广泛应用于传统的特征提取和深度学习的目标检测算法中。
NMS原理是通过筛选出局部极大值得到最优解。
在2维边缘提取中体现在提取边缘轮廓后将一些梯度方向变化率较小的点筛选掉,避免造成干扰。
在三维关键点检测中也起到重要作用,筛选掉特征中非局部极值。
在目标检测方面如Yolo和RCNN等模型中均有使用,可以将较小分数的输出框过滤掉,同样,在三维基于点云的目标检测模型中亦有使用。


二、示例

1.opencv示例

查看opencv源码,可以知道canny算子中使用了nms,再用sobel等梯度计算方法生成的梯度矩阵上求取的极大值。
其计算方法是比较中心点与其邻域的梯度值,如果为最大值,则保留,不是的话为0。
源码可见:
Canny算法解析,opencv源码实现及实例

    //读取图片
    Mat img = imread("true.jpg");
    Mat Grayimg;

    resize(img, img, Size(400, 600), 0, 0, INTER_LINEAR);
    cvtColor(img, Grayimg, COLOR_RGB2GRAY);     //转为灰度图
    Canny(Grayimg, Grayimg, 100, 300, 3);

    imshow("picture0", img);
    imshow("picture", Grayimg);

    waitKey(0);
    return 0;

2.PCL示例

点云关键点特征提取算法经常会使用nms提取极大值点。
如3D SIFT关键点检测中需要计算尺度空间中像素点的26邻域的极值点。
算法原理参考:
PCL 3D-SIFT关键点检测(Z方向梯度约束)

pcl::SIFTKeypoint<pcl::PointXYZ, pcl::PointWithScale> sift;
pcl::PointCloud<pcl::PointWithScale> result;
sift.setInputCloud(cloud_xyz);
pcl::search::KdTree<pcl::PointXYZ>::Ptr tree(new pcl::search::KdTree<pcl::PointXYZ>());
sift.setSearchMethod(tree); 
sift.setScales(0.01f, 7, 20);
sift.setMinimumContrast(0.001f);
sift.compute(result);  

NMS(非极大值抑制)算法详解与示例_第1张图片

3.目标检测中nms示例

nms在深入学习领域常用于对box的得分进行极大值筛选,在rcnn,yolo, pointnet等模型中广泛使用。
其算法流程大致为:
1:计算所有box的得分。
2:排序,依次与得分高的box的IOU进行对比,如果大于设定的阈值,就删除该框。
在yolo源代码detect.py可见:

pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
conf_thres:置信度即得分score的阈值,yolo为0.25。
iou_thres:重叠度阈值,为0.45
classes:类别数,可以设置保留哪一类的box
agnostic_nms:是否去除不同类别之间的框,默认false
max_det:一张图片中最大识别种类的个数,默认300
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output

NMS(非极大值抑制)算法详解与示例_第2张图片


你可能感兴趣的:(图像算法,算法,c++,图像处理)