non_max_suppression
思考box形成过程NMS
解决的最大的问题就是一个目标有多个框的情况下,对同一目标只保留一个框。通常来说,NSM
的阈值越大,保留下的框越多,同一个目标出现多个框的概率越高。但是这个说法不绝对,还是要区分场景和目标尺寸的。
NMS
def non_max_suppression(boxes, conf_thres=0.5, nms_thres=0.3):
detection = boxes
# 1、找出该图片中得分大于门限函数的框。在进行重合框筛选前就进行得分的筛选可以大幅度减少框的数量。
mask = detection[:, 4] >= conf_thres
detection = detection[mask]
if not np.shape(detection)[0]:
return []
best_box = []
scores = detection[:,4]
# 2、根据得分对框进行从大到小排序。
arg_sort = np.argsort(scores)[::-1]
detection = detection[arg_sort]
while np.shape(detection)[0]>0:
# 3、每次取出得分最大的框,计算其与其它所有预测框的重合程度,重合程度过大的则剔除。
best_box.append(detection[0])
if len(detection) == 1:
break
ious = iou(best_box[-1],detection[1:])
detection = detection[1:][ious<nms_thres]
return np.array(best_box)
conf_thres
阈值进行置信度筛选。detection[0:4]
为box坐标,detection[4]
为置信度。detection = detection[arg_sort]
列表。detection[0]
,与其他框detection[1:]
计算iou
,剔除iou
大于nms_thres
阈值的框。保留与其iou
小于nms_thres
阈值的框,进行迭代。NMS
def non_max_suppression(prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
# prediction[0:4]为box坐标,prediction[4]为置信度,prediction[5:5 + num_classes]为每类任务得分
output = [None for _ in range(len(prediction))]
#----------------------------------------------------------#
# 对输入图片进行循环,一般只会进行一次
#----------------------------------------------------------#
for i, image_pred in enumerate(prediction):
# 返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)
class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
#----------------------------------------------------------#
# 利用置信度进行第一轮筛选
#----------------------------------------------------------#
conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()
if not image_pred.size(0):
continue
#-------------------------------------------------------------------------#
# detections [num_anchors, 7]
# 7的内容为:x1, y1, x2, y2, obj_conf, class_conf, class_pred
#-------------------------------------------------------------------------#
detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
detections = detections[conf_mask]
nms_out_index = boxes.batched_nms(
detections[:, :4],
detections[:, 4] * detections[:, 5],
detections[:, 6],
nms_thres,
)
output[i] = detections[nms_out_index]
prediction[0:4]
为box
坐标,prediction[4]
为置信度,prediction[5:5 + num_classes]
为每类任务得分。`torch.max(a,1)`
# 返回每一行中最大值的那个元素,且返回其索引(返回最大元素在这一行的列索引)
class_conf
为torch.max
返回的每行最大值,即得分最大的值。class_pred
为torch.max
返回的每行最大值的索引,即类别。
detections[0:4]
为box坐标,detections[4]
为置信度,detections[5]
为分类得分,detections[6]
为类别索引。
conf_thres
阈值进行分类置信度筛选。torchvision.ops.boxes.batched_nms
进行NMS
。Torchvision.ops.batched_nms()
根据每个类别进行过滤,只对同一种类别进行计算IOU和阈值过滤。
参数:
参数名 | 参数类型 | 说明 |
---|---|---|
boxes |
Tensor | 预测框 |
scores |
Tensor | 预测置信度 |
idxs |
Tensor | 预测框类别 |
iou_threshold |
float | IOU阈值 |
返回:
参数名 | 参数类型 | 说明 |
---|---|---|
keep | Tensor | 预测NMS 过滤后的bouding boxes 索引(降序排列) |
boxes
的形式是(x1,y1,x2,y2)
的格式,其中0 <= x1 < x2
,0 <= y1 < y2
。即左上角右下角的形式。
那么需要将网络输出的中心点和宽高的形式构造成左上角右下角的形式。
box_corner = prediction.new(prediction.shape)
box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
prediction[:, :, :4] = box_corner[:, :, :4]
torch
的NMS
引发的一系列box
形式一般而言,网络输出的坐标,都是中心点和宽高的形式。
比如:
outputs[..., :2] = (outputs[..., :2] + grids) * strides
outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
# 归一化
outputs[..., [0,2]] = outputs[..., [0,2]] / input_shape[1]
outputs[..., [1,3]] = outputs[..., [1,3]] / input_shape[0]
pred_boxes[..., 0] = x.data + grid_x
pred_boxes[..., 1] = y.data + grid_y
pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
pred_boxes[..., 3] = torch.exp(h.data) * anchor_h
既然网络输出的结果是中心点和宽高的形式,但是labelimg
标注的xml
又是左上角右下角的形式。
所以在训练的时候,应该在获取target
时做了框的坐标转换。通常在get_target
方法中,有如下:
# in_h、in_w是输入图片要求的高和宽
batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
batch_target[:, 4] = targets[b][:, 4]
为什么要这么搞而不是直接用左上角右下角坐标呢?
个人觉得还是因为中心点和宽高的形式带有语义信息,这样进行训练的话,会减少网络出错的可能。