YOLOv5: NMS

import time

import torch
import torchvision

# xywh2xyxy and box_iou are helper functions from the YOLOv5 utils modules


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=()):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, one (n, 6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # number of classes
    # True where the box (objectness) confidence at index 4 exceeds conf_thres, False otherwise
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    # multi-label only makes sense when there is more than one class
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    # disabled by default; set to True to enable merge-NMS
    merge = False  # use merge-NMS

    t = time.time()
    # output container: one empty (0, 6) tensor per image in the batch
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    # xi is the image index, x is that image's tensor of predicted boxes
    for xi, x in enumerate(prediction):  # image index, image inference
        # xc[xi] is this image's boolean mask (True where objectness > conf_thres);
        # indexing with it keeps only the candidate boxes above the threshold
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        # only used for auto-labelling with a priori labels, so the details are skipped here
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        # see explanation 2 below
        # multiply each class confidence (columns 5 onward) by the objectness confidence (column 4)
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        # see explanation 3 below
        if multi_label:
            # i: row index of every (box, class) pair whose class score exceeds the threshold
            # j: the class index of that pair
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            # rebuild x so that each row holds one qualifying (box, class) pair:
            # the box coordinates, the score of that class, and the class index
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        # if there are more boxes than max_nms, sort by confidence and keep only the top max_nms
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        # offset each box by class index * max_wh so boxes of different classes never overlap,
        # which makes a single NMS call behave like per-class NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        # the boxes (shifted by their class offset) and their confidence scores
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        # indices of the bounding boxes kept after NMS (sorted by decreasing score)
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
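            # for every kept box i, replace its coordinates with the confidence-weighted
            # average of all boxes whose IoU with box i exceeds iou_thres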
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
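
For context, here is a minimal sketch of how this function is typically called on raw model output. The model call is only hinted at; the dummy tensor simply mirrors the shape and layout described in section 1 below.

import torch

# pred = model(img)[0]  # in YOLOv5, raw inference output of shape (batch, num_boxes, 5 + nc)
pred = torch.rand(1, 26460, 11)  # dummy stand-in with the same layout as in section 1
detections = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
for det in detections:  # one (n, 6) tensor per image: [x1, y1, x2, y2, conf, cls]
    print(det.shape)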

1

  • The input prediction is a tensor;
  • each row of data represents one predicted box;
  • the first four values are the box coordinates;
  • the fifth value is the box (objectness) confidence;
  • the last six values are the per-class confidences; there are six of them here because this model has six classes.
tensor([[[3.8242e+00, 3.9844e+00, 9.9609e+00, 1.0477e+01, 2.0862e-06,
          9.3384e-03, 3.2898e-02, 2.3975e-03, 6.5369e-02, 8.0322e-01,
          4.5258e-02],
         [1.1172e+01, 3.0117e+00, 2.1844e+01, 7.0781e+00, 4.4703e-06,
          9.1629e-03, 3.9490e-02, 2.4433e-03, 6.0974e-02, 7.5830e-01,
          5.0812e-02],
          ...
          ...
         [5.8500e+02, 6.4350e+02, 1.4875e+02, 9.3812e+01, 1.7881e-06,
          2.4658e-02, 2.3331e-02, 9.5978e-03, 5.4492e-01, 1.3086e-01,
          5.5939e-02],
         [6.2100e+02, 6.5300e+02, 1.1762e+02, 1.0444e+02, 1.4901e-06,
          3.0045e-02, 3.2715e-02, 1.2238e-02, 3.6963e-01, 2.3193e-01,
          5.3101e-02]]], device='cuda:0', dtype=torch.float16)
  • prediction.shape[0] = 1: the batch size, a single image here

  • prediction.shape[1] = 26460: the number of predicted boxes

  • prediction.shape[2] = 11: the number of values per box (4 coordinates + 1 objectness confidence + 6 class confidences)

  • Looking at the first row (a small slicing sketch of this layout follows below):

    [3.8242e+00, 3.9844e+00, 9.9609e+00, 1.0477e+01, 2.0862e-06, 9.3384e-03, 3.2898e-02, 2.3975e-03, 6.5369e-02, 8.0322e-01, 4.5258e-02]

    [center_x, center_y, width, height, obj_conf, cls_conf0, cls_conf1, cls_conf2, cls_conf3, cls_conf4, cls_conf5]
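
A small slicing sketch of this layout, using a tiny random tensor instead of the real 26460-box output, shows what nc, xc, and the first filtering step inside the function compute:

import torch

prediction = torch.rand(1, 5, 11)  # 1 image, 5 boxes, 4 + 1 + 6 values per box
nc = prediction.shape[2] - 5       # number of classes -> 6
xc = prediction[..., 4] > 0.25     # objectness mask, shape (1, 5)
x = prediction[0][xc[0]]           # keep only the candidate boxes of image 0
print(nc, xc, x.shape)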

2

import torch

x = torch.tensor([[2.9950e+02, 2.9025e+02, 1.6962e+02, 3.0675e+02,
                   2.3785e-03,
                   2.0580e-03, 6.3610e-04, 9.6226e-04, 9.9756e-01, 3.0270e-03, 3.2864e-03]],
                   dtype=torch.float16)
print(x[:, 5:])
print(x[:, 4:5])
x[:, 5:] *= x[:, 4:5]
print(x[:, 5:])

Output

tensor([[2.0580e-03, 6.3610e-04, 9.6226e-04, 9.9756e-01, 3.0270e-03, 3.2864e-03]], dtype=torch.float16)
tensor([[0.0024]], dtype=torch.float16)
tensor([[4.8876e-06, 1.4901e-06, 2.2650e-06, 2.3727e-03, 7.2122e-06, 7.8082e-06]], dtype=torch.float16)

print(x[:, 5:]): the six class confidences (everything from column 5 onward)
print(x[:, 4:5]): column 4, the objectness confidence (2.3785e-03, printed rounded as 0.0024)
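
Right after this multiplication, the function converts the boxes from (center_x, center_y, width, height) to (x1, y1, x2, y2) with xywh2xyxy. Below is a minimal stand-alone sketch of that conversion, modeled on the YOLOv5 helper and fed with the box from the example above:

import torch

def xywh2xyxy(x):
    # (center_x, center_y, width, height) -> (x1, y1, x2, y2)
    y = x.clone()
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top-left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top-left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom-right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom-right y
    return y

print(xywh2xyxy(torch.tensor([[299.50, 290.25, 169.62, 306.75]])))
# -> tensor([[214.6900, 136.8750, 384.3100, 443.6250]]) up to float rounding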

3

# Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

x[:, 5:] > conf_thres: True wherever a class confidence exceeds conf_thres, False otherwise

tensor([[False, True, False,  False, False, False],
        [False, False, False,  False, False, False],
        [False, False, False,  True, False, False],
        [False, False, False,  True, False, False],
        [False, False, False,  True, False, True],
        [False, False, False,  True, False, False]], device='cuda:0')

i:

tensor([0, 2, 3, 4, 4, 5], device='cuda:0')
i is the row index of each True entry: row 0 has one True, row 2 has one, row 3 has one, row 4 has two (so 4 appears twice), and row 5 has one; row 1 has no True, so it does not appear.

j:

tensor([1, 3, 3, 3, 5, 3], device='cuda:0')
j is the column index of each True entry: row 0's True is at class 1, rows 2 and 3 at class 3, row 4 at classes 3 and 5, and row 5 at class 3.

In other words, j is simply the class index (and i the box index) of every (box, class) pair that clears the threshold; a tiny end-to-end example follows.
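
A tiny end-to-end example of this indexing, with made-up numbers, 4 classes instead of 6, and a 0.25 threshold, shows how i and j are used to build the n x 6 detection matrix:

import torch

conf_thres = 0.25
box = torch.tensor([[ 0.,  0., 10., 10.],   # boxes already converted to xyxy
                    [ 5.,  5., 20., 20.],
                    [ 8.,  8., 30., 30.]])
# columns: 4 box values (unused below because box is passed separately),
# 1 objectness confidence, then 4 class scores
x = torch.tensor([[0., 0., 0., 0., 0.9, 0.10, 0.60, 0.00, 0.30],
                  [0., 0., 0., 0., 0.8, 0.00, 0.00, 0.90, 0.00],
                  [0., 0., 0., 0., 0.7, 0.00, 0.00, 0.00, 0.00]])

i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
print(i)  # tensor([0, 0, 1]) -> row of each qualifying (box, class) pair
print(j)  # tensor([1, 3, 2]) -> class index of each pair
out = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
print(out)  # 3 x 6 matrix: [x1, y1, x2, y2, class score, class index]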
