一批图像输入 SSD 深度神经网络之后，网络输出的是每个先验框（anchor）对应的边界框（bbox）偏移量预测以及每一个类别的置信度，需要对这些输出进行后处理才能得到最终的检测结果。
对每一张图片的预测值分别进行如下处理：先对类别置信度做 softmax；按每个先验框在前景类别上的最高得分选出 top_k 个先验框；将这些先验框对应的偏移量解码为边界框坐标；最后进行多类别非极大值抑制。
多类别非极大值抑制（multi-class NMS）：对每个前景类别分别保留得分高于阈值的框，并且只在同一类别的框之间进行非极大值抑制：
详细代码:
import torch
def detector(priors_cxcy, predicted_locs, predicted_scores, min_score, max_overlap, top_k):
    """Post-process a batch of SSD outputs into per-image detections.

    Params:
        priors_cxcy: [8732, 4] prior (anchor) boxes shared by every image
            (the ``cxcy`` suffix presumably means center-size format -- confirm).
        predicted_locs: [N, 8732, 4] per-image box regression outputs.
        predicted_scores: [N, 8732, num_classes] per-image class scores.
        min_score, max_overlap, top_k: forwarded unchanged to
            ``get_bboxes_single``.

    Returns:
        A list with one ``get_bboxes_single`` result per image, in batch order.
    """
    # Each image is processed independently; iterating the two tensors in
    # lockstep walks their first (batch) dimension.
    return [
        get_bboxes_single(priors_cxcy, image_locs, image_scores,
                          min_score, max_overlap, top_k)
        for image_locs, image_scores in zip(predicted_locs, predicted_scores)
    ]
def get_bboxes_single(anchors, predicted_locs, predicted_scores, min_score, max_overlap, top_k):
    """Turn one image's raw SSD outputs into final detections.

    Params:
        anchors: [8732, 4] prior boxes.
        predicted_locs: [8732, 4] encoded box outputs, decoded against
            ``anchors`` via ``gcxgcy_to_cxcy`` and then ``cxcy_to_xy``.
        predicted_scores: [8732, num_classes] unnormalized class scores
            (softmax is applied here); column 0 is the background class.
        min_score: score threshold applied inside ``multiclass_nms``.
        max_overlap: IoU threshold used by the NMS step.
        top_k: number of anchors with the highest foreground score kept
            before decoding and NMS. Values larger than the number of anchors
            are clamped, so every anchor is kept in that case.

    Returns:
        ``(det_bboxes, det_labels)`` from ``multiclass_nms``. The labels are
        shifted by +1 so they index the original class columns, where 0 is
        background.
    """
    assert anchors.size(0) == predicted_locs.size(0) == predicted_scores.size(0)
    scores = predicted_scores.softmax(-1)
    # Best foreground probability per anchor; column 0 (background) is dropped.
    max_scores, _ = scores[:, 1:].max(dim=1)
    # Tensor.topk raises if k exceeds the dimension size, so clamp it.
    k = min(top_k, max_scores.size(0))
    _, topk_inds = max_scores.topk(k)
    anchors = anchors[topk_inds, :]
    predicted_locs = predicted_locs[topk_inds, :]
    scores = scores[topk_inds, :]
    bboxes = cxcy_to_xy(gcxgcy_to_cxcy(predicted_locs, anchors))  # decode
    det_bboxes, det_labels = multiclass_nms(bboxes, scores, min_score, max_overlap)
    # multiclass_nms labels count foreground columns only (background was
    # skipped), so +1 restores the original class ids.
    det_labels = det_labels + 1
    return det_bboxes, det_labels
def multiclass_nms(multi_bboxes, multi_scores, score_thr, threshold, max_num=-1):
    """Apply score thresholding and per-class NMS to class-agnostic boxes.

    Column 0 of ``multi_scores`` is the background class and is skipped, so
    the returned labels are 0-based indices over the remaining (foreground)
    columns.

    Params:
        multi_bboxes: [n, 4] boxes shared by every class.
        multi_scores: [n, num_class] per-class scores, background in column 0.
        score_thr: a (box, class) pair survives only if its score > score_thr.
        threshold: IoU threshold forwarded to ``batch_nms``.
        max_num: when > 0, keep at most this many detections after NMS.

    Returns:
        ``(dets, labels)``. ``dets`` is [k, 5]: the 4 box coordinates plus the
        score. ``labels`` is a [k] long tensor of foreground class indices.
        Both are empty when no (box, class) pair passes ``score_thr``.
    """
    fg_count = multi_scores.size(1) - 1
    fg_scores = multi_scores[:, 1:]
    # One copy of each box per foreground class, aligned element-wise with fg_scores.
    tiled_boxes = multi_bboxes[:, None].expand(-1, fg_count, 4)
    passed = fg_scores > score_thr
    cand_boxes = tiled_boxes[passed]
    cand_scores = fg_scores[passed]
    # The column index of each surviving (box, class) pair is its foreground label.
    cand_labels = passed.nonzero()[:, 1]

    if cand_boxes.numel() == 0:
        no_dets = multi_bboxes.new_zeros((0, 5))
        no_labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
        return no_dets, no_labels

    dets, keep = batch_nms(cand_boxes, cand_scores, cand_labels, threshold)
    if max_num > 0:
        dets, keep = dets[:max_num], keep[:max_num]
    return dets, cand_labels[keep]
def batch_nms(bboxes, scores, inds, threshold):
    """Class-aware NMS done with a single call to ``nms``.

    Each class's boxes are translated into a region that cannot overlap any
    other class's region. One NMS pass therefore only suppresses boxes that
    share the same class index.

    Params:
        bboxes: [n, 4] boxes.
        scores: [n] scores.
        inds: [n] non-negative integer class index per box.
        threshold: IoU threshold forwarded to ``nms``.

    Returns:
        ``(dets, keep)``. ``dets`` is [k, 5]: the kept ORIGINAL (unshifted)
        boxes with their score appended, in the order ``nms`` returns them.
        ``keep`` holds their indices into the inputs. Both are empty for
        empty input.
    """
    if bboxes.numel() == 0:
        # bboxes.max() below would raise on an empty tensor.
        return bboxes.new_zeros((0, 5)), bboxes.new_zeros((0,), dtype=torch.long)

    # Shift class c by c * (coordinate span + 1). Class c then occupies
    # [min + c*span, max + c*span], and class c+1 starts 1 unit beyond that.
    # This keeps classes disjoint even when some coordinates are negative,
    # which a plain (max + 1) offset does not guarantee.
    span = bboxes.max() - bboxes.min() + 1
    offset = inds.to(bboxes) * span
    bboxes_for_nms = bboxes + offset[:, None]
    dets, keep = nms(torch.cat([bboxes_for_nms, scores[:, None]], -1), threshold)
    # Return the original coordinates, not the shifted ones used for NMS.
    return torch.cat([bboxes[keep], dets[:, -1:]], -1), keep
def nms(dets, threshold):
    """Greedy single-class non-maximum suppression.

    Params:
        dets: [n, 5+] tensor. Columns 0-3 are (x1, y1, x2, y2) and column 4
            is the score.
        threshold: IoU threshold. A candidate is discarded when its IoU with
            the currently selected (higher-scoring) box exceeds this value.

    Returns:
        ``(kept_dets, keep)``. ``keep`` is a long tensor, on ``dets``' device,
        of the selected row indices in selection order (descending score).
        ``kept_dets`` is ``dets[keep]``.
    """
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]
    areas = (x2 - x1) * (y2 - y1)

    _, order = scores.sort(descending=True)
    keep = []
    while order.numel() > 0:
        i = int(order[0])
        keep.append(i)
        rest = order[1:]
        if rest.numel() == 0:
            break
        # Intersection of box i with every remaining candidate.
        xx1 = torch.maximum(x1[rest], x1[i])
        yy1 = torch.maximum(y1[rest], y1[i])
        xx2 = torch.minimum(x2[rest], x2[i])
        yy2 = torch.minimum(y2[rest], y2[i])
        inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0)
        iou = inter / (areas[i] + areas[rest] - inter)
        # Keep only candidates that do not overlap box i too much. A NaN IoU
        # (0/0 from degenerate boxes) compares False, so that candidate is dropped.
        order = rest[iou <= threshold]

    keep = torch.as_tensor(keep, dtype=torch.long, device=dets.device)
    return dets[keep, :], keep