def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=()):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
        list of detections, one (n, 6) tensor per image [xyxy, conf, cls]
    """
    nc = prediction.shape[2] - 5  # number of classes
    # True where index 4 (the objectness / box confidence) exceeds conf_thres, False otherwise
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    # only meaningful when there is more than one class
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    # disabled by default; change to True to enable merge-NMS
    merge = False  # use merge-NMS

    t = time.time()
    # one output container per image in the batch
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    # xi is the index of the image within the batch
    # x is that image's tensor of predictions, one row per predicted box
    for xi, x in enumerate(prediction):  # image index, image inference
        # xc[xi] is the boolean mask for this image: True where the box confidence exceeds conf_thres
        # indexing with it keeps only the candidate rows whose confidence passed the threshold
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling (not used in this walkthrough)
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf (see explanation 2 below)
        # multiply each of the trailing class confidences by the value at index 4 (the box confidence)
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls) (see explanation 3 below)
        if multi_label:
            # multi-class case: i is the row (box) index, j is the index of every class in that row
            # whose confidence exceeds the threshold
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            # build the final x: one row per qualifying (box, class) pair, holding the box coordinates,
            # that class's score and that class's index
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        # if there are more boxes than max_nms, sort by confidence and keep the top max_nms
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS: offsetting each box by class index * max_wh pushes boxes of different
        # classes far apart, so one NMS call never suppresses across classes (unless agnostic)
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        # every box (offset by its class) and its score
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        # indices of the bounding boxes kept after NMS, sorted by descending score
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
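The helper xywh2xyxy used above converts boxes from center format to corner format. For reference, here is a minimal sketch modeled on the YOLOv5 utility of the same name (the exact version in your repository may differ slightly):

import torch

def xywh2xyxy(x):
    # Convert n x 4 boxes from [center_x, center_y, w, h] to [x1, y1, x2, y2],
    # where (x1, y1) is the top-left corner and (x2, y2) the bottom-right corner
    y = x.clone()
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top-left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top-left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom-right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom-right y
    return y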
1. prediction

prediction is a tensor:
tensor([[[3.8242e+00, 3.9844e+00, 9.9609e+00, 1.0477e+01, 2.0862e-06,
9.3384e-03, 3.2898e-02, 2.3975e-03, 6.5369e-02, 8.0322e-01,
4.5258e-02],
[1.1172e+01, 3.0117e+00, 2.1844e+01, 7.0781e+00, 4.4703e-06,
9.1629e-03, 3.9490e-02, 2.4433e-03, 6.0974e-02, 7.5830e-01,
5.0812e-02],
...
...
[5.8500e+02, 6.4350e+02, 1.4875e+02, 9.3812e+01, 1.7881e-06,
2.4658e-02, 2.3331e-02, 9.5978e-03, 5.4492e-01, 1.3086e-01,
5.5939e-02],
[6.2100e+02, 6.5300e+02, 1.1762e+02, 1.0444e+02, 1.4901e-06,
3.0045e-02, 3.2715e-02, 1.2238e-02, 3.6963e-01, 2.3193e-01,
5.3101e-02]]], device='cuda:0', dtype=torch.float16)
prediction.shape[0] = 1: the batch size; a single image, so it is 1.
prediction.shape[1] = 26460: this should be the number of predicted boxes.
prediction.shape[2] = 11: the number of values per box (4 box coordinates + 1 objectness/box confidence + 6 class confidences).
Let's look at the first row:
[3.8242e+00, 3.9844e+00, 9.9609e+00, 1.0477e+01, 2.0862e-06, 9.3384e-03, 3.2898e-02, 2.3975e-03, 6.5369e-02, 8.0322e-01, 4.5258e-02]
[center_x, center_y, width, height, obj_conf, cls_conf0, cls_conf1, cls_conf2, cls_conf3, cls_conf4, cls_conf5]
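As a quick sanity check on this layout, here is a minimal sketch; the tensor below is random and only stands in for a real model output, so the shapes are the point, not the values:

import torch

# Toy prediction tensor with the layout above: (batch, num_boxes, 4 + 1 + nc)
prediction = torch.rand(1, 26460, 11)   # random values, for illustration only

nc = prediction.shape[2] - 5            # number of classes: 11 - 5 = 6
row = prediction[0, 0]                  # first predicted box of the first image
xywh = row[:4]                          # center_x, center_y, width, height
obj_conf = row[4]                       # objectness (box) confidence
cls_conf = row[5:]                      # one confidence per class, length nc
print(nc, xywh.shape, obj_conf.shape, cls_conf.shape)
# 6 torch.Size([4]) torch.Size([]) torch.Size([6])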
2. Compute conf: x[:, 5:] *= x[:, 4:5]

import torch

x = torch.tensor([[2.9950e+02, 2.9025e+02, 1.6962e+02, 3.0675e+02,  # box (xywh)
                   2.3785e-03,                                      # objectness / box confidence
                   2.0580e-03, 6.3610e-04, 9.6226e-04, 9.9756e-01, 3.0270e-03, 3.2864e-03]],  # class confidences
                  dtype=torch.float16)
print(x[:, 5:])
print(x[:, 4:5])
x[:, 5:] *= x[:, 4:5]
print(x[:, 5:])
Output:
tensor([[2.0580e-03, 6.3610e-04, 9.6226e-04, 9.9756e-01, 3.0270e-03, 3.2864e-03]], dtype=torch.float16)
tensor([[0.0024]], dtype=torch.float16)
tensor([[4.8876e-06, 1.4901e-06, 2.2650e-06, 2.3727e-03, 7.2122e-06, 7.8082e-06]], dtype=torch.float16)

x[:, 5:] selects the columns from index 5 onward, i.e. the six class confidences.
x[:, 4:5] selects column 4 (the objectness / box confidence) as an (n, 1) slice; the printed 0.0024 is just the rounded float16 display of 2.3785e-03.
After the in-place multiply, each class confidence has been scaled by the objectness, i.e. conf = obj_conf * cls_conf.
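Why x[:, 4:5] rather than x[:, 4]? The sliced form keeps the column dimension, so it broadcasts cleanly across the class columns. A quick sketch with two made-up rows:

import torch

# Two rows with the same (4 box values + 1 objectness + 6 class confidences) layout
x = torch.tensor([[0., 0., 1., 1., 0.50, 0.2, 0.9, 0.1, 0.3, 0.4, 0.6],
                  [0., 0., 1., 1., 0.25, 0.8, 0.1, 0.2, 0.3, 0.4, 0.5]])

print(x[:, 4].shape)    # torch.Size([2])    -> 1-D, the column dimension is dropped
print(x[:, 4:5].shape)  # torch.Size([2, 1]) -> shape (n, 1), broadcasts across the 6 class columns
x[:, 5:] *= x[:, 4:5]   # row 0 classes scaled by 0.50, row 1 classes scaled by 0.25
print(x[:, 5:])
# tensor([[0.1000, 0.4500, 0.0500, 0.1500, 0.2000, 0.3000],
#         [0.2000, 0.0250, 0.0500, 0.0750, 0.1000, 0.1250]])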
3. Detections matrix nx6 (xyxy, conf, cls)

if multi_label:
    i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
    x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
else:  # best class only
    conf, j = x[:, 5:].max(1, keepdim=True)
    x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
x[:, 5:] > conf_thres: each class confidence greater than conf_thres becomes True, otherwise False:
tensor([[False, True, False, False, False, False],
[False, False, False, False, False, False],
[False, False, False, True, False, False],
[False, False, False, True, False, False],
[False, False, False, True, False, True],
[False, False, False, True, False, False]], device='cuda:0')
i:
tensor([0, 2, 3, 4, 4, 5], device='cuda:0')
i holds the row (box) index of each True entry: row 0 has one True, row 2 has one, row 3 has one, row 4 has two (so it appears twice), and row 5 has one.
j:
tensor([1, 3, 3, 3, 5, 3], device='cuda:0')
j holds the column index of each True entry: row 0 is True at class 1, row 2 at class 3, row 3 at class 3, row 4 at classes 3 and 5, and row 5 at class 3.
So j is simply the class index.
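To make the indexing concrete, here is a small self-contained sketch with made-up numbers (3 boxes, 4 classes). Note that it indexes the class block directly, so the + 5 column offset from the original code is not needed here:

import torch

conf_thres = 0.25
# Class-confidence block (the x[:, 5:] part) for 3 boxes and 4 classes, values made up for the demo
cls = torch.tensor([[0.1, 0.6, 0.2, 0.0],
                    [0.0, 0.1, 0.0, 0.2],
                    [0.7, 0.0, 0.3, 0.9]])

i, j = (cls > conf_thres).nonzero(as_tuple=False).T
print(i)  # tensor([0, 2, 2, 2]) -> row (box) index of every True entry
print(j)  # tensor([1, 0, 2, 3]) -> column (class) index of every True entry

# One detection row per (box, class) pair above the threshold: [x1, y1, x2, y2, conf, cls]
box = torch.tensor([[0., 0., 10., 10.],
                    [0., 0., 20., 20.],
                    [0., 0., 30., 30.]])  # xyxy boxes, made up for the demo
det = torch.cat((box[i], cls[i, j, None], j[:, None].float()), 1)
print(det.shape)  # torch.Size([4, 6])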