All the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All the code below has been tested under PyTorch 1.4 and confirmed to run correctly.
Once training is finished, the model output has to be decoded before testing. A forward pass through the RetinaNet class gives us the cls heads and reg heads, but the reg heads predict tx,ty,tw,th, so we need the coordinates of the corresponding anchor boxes to convert them into predicted box coordinates. The conversion rule is simply the inverse of the box-to-(tx,ty,tw,th) encoding described in Part 4 of this series (从零实现RetinaNet(四)).
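Concretely, writing w_a and h_a for an anchor's width and height, (x_a, y_a) for its center, and using the (0.1, 0.1, 0.2, 0.2) scale factors from the encoding step, the inverse transform computed by the code below is:

x_ctr = 0.1 * tx * w_a + x_a
y_ctr = 0.1 * ty * h_a + y_a
w = w_a * exp(0.2 * tw)
h = h_a * exp(0.2 * th)

The predicted box corners then follow as (x_ctr - 0.5*w, y_ctr - 0.5*h, x_ctr + 0.5*w, y_ctr + 0.5*h).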
The code that converts the regression predictions into box predictions is as follows:
def snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
        self, reg_heads, anchors):
    """
    snap reg heads to pred bboxes
    reg_heads:[anchor_nums,4],4:[tx,ty,tw,th]
    anchors:[anchor_nums,4],4:[x_min,y_min,x_max,y_max]
    """
    anchors_wh = anchors[:, 2:] - anchors[:, :2]
    anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh
    device = anchors.device
    factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)
    # undo the (0.1,0.1,0.2,0.2) scaling applied during target encoding
    reg_heads = reg_heads * factor
    pred_bboxes_wh = torch.exp(reg_heads[:, 2:]) * anchors_wh
    pred_bboxes_ctr = reg_heads[:, :2] * anchors_wh + anchors_ctr
    pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
    pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh
    pred_bboxes = torch.cat(
        [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], dim=1)
    pred_bboxes = pred_bboxes.int()
    # clip the predicted boxes to the input image boundary
    pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
    pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
    pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2], max=self.image_w - 1)
    pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3], max=self.image_h - 1)
    # pred bboxes shape:[anchor_nums,4]
    return pred_bboxes
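As a quick sanity check (a tiny standalone example of mine, not part of the repository code): regression outputs of all zeros should reproduce the anchor box exactly, since exp(0) = 1 and the center offsets vanish.

import torch

# one anchor: (x_min,y_min,x_max,y_max)=(100,100,200,200), so w=h=100, ctr=(150,150)
anchors = torch.tensor([[100., 100., 200., 200.]])
reg_heads = torch.zeros(1, 4)  # tx=ty=tw=th=0
anchors_wh = anchors[:, 2:] - anchors[:, :2]
anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh
reg = reg_heads * torch.tensor([[0.1, 0.1, 0.2, 0.2]])
pred_wh = torch.exp(reg[:, 2:]) * anchors_wh
pred_ctr = reg[:, :2] * anchors_wh + anchors_ctr
print(torch.cat([pred_ctr - 0.5 * pred_wh, pred_ctr + 0.5 * pred_wh], dim=1))
# tensor([[100., 100., 200., 200.]])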
The standard approach to NMS post-processing is as follows. First, sort all candidate detections by classification score in descending order and record which classes appear among them. Then iterate over those detected classes. For each class, extract all candidates belonging to it (since we sorted at the start, the extracted candidates are still in score order), move the first candidate into the keep set, compute the IoU between every remaining candidate and that one, and discard all candidates whose IoU exceeds the threshold; for RetinaNet this threshold is 0.5. Repeat the same procedure on whatever candidates survive, again moving the first one into the keep set, until no candidates remain, at which point NMS for that class is done. Once every class has been processed, NMS is complete.
In other object detection codebases I have noticed that many implementations do not run NMS per class (i.e., candidates of all classes are run through NMS together). I tried this class-agnostic variant as well and found it consistently about 0.2~0.5 mAP worse than the standard per-class method, so the implementation below sticks with the standard approach.
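As an aside, if you would rather not hand-roll the IoU loop, torchvision ships a compiled torchvision.ops.nms, and the usual coordinate-offset trick turns its class-agnostic NMS into an exact per-class NMS: shift each class's boxes into a disjoint coordinate range so that boxes of different classes can never overlap, then run a single NMS call. A minimal sketch (per_class_nms is my own hypothetical wrapper, not code from this repository):

import torch
from torchvision.ops import nms

def per_class_nms(boxes, scores, classes, iou_threshold=0.5):
    """boxes:[N,4] x_min,y_min,x_max,y_max; scores:[N]; classes:[N]"""
    if boxes.numel() == 0:
        return torch.empty((0, ), dtype=torch.int64, device=boxes.device)
    # offset each class into its own coordinate range; boxes of different
    # classes then have zero IoU, so one nms call equals per-class NMS
    offsets = classes.to(boxes) * (boxes.max() + 1)
    return nms(boxes + offsets[:, None], scores, iou_threshold)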
The NMS post-processing code is implemented as follows:
def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
    """
    one_image_scores:[anchor_nums],classification predict scores
    one_image_classes:[anchor_nums],class indexes for predict scores
    one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max
    """
    # sort all candidates by classification score in descending order
    sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(
        one_image_scores, descending=True)
    sorted_one_image_classes = one_image_classes[
        sorted_one_image_scores_indexes]
    sorted_one_image_pred_bboxes = one_image_pred_bboxes[
        sorted_one_image_scores_indexes]
    sorted_pred_bboxes_w_h = (sorted_one_image_pred_bboxes[:, 2:] -
                              sorted_one_image_pred_bboxes[:, :2])
    sorted_pred_bboxes_areas = (sorted_pred_bboxes_w_h[:, 0] *
                                sorted_pred_bboxes_w_h[:, 1])
    # record which classes appear among the candidates
    detected_classes = torch.unique(sorted_one_image_classes, sorted=True)
    keep_scores, keep_classes, keep_pred_bboxes = [], [], []
    for detected_class in detected_classes:
        # extract this class's candidates (still sorted by score)
        single_class_mask = sorted_one_image_classes == detected_class
        single_class_scores = sorted_one_image_scores[single_class_mask]
        single_class_pred_bboxes = sorted_one_image_pred_bboxes[
            single_class_mask]
        single_class_pred_bboxes_areas = sorted_pred_bboxes_areas[
            single_class_mask]
        single_class = sorted_one_image_classes[single_class_mask]
        single_keep_scores, single_keep_classes, single_keep_pred_bboxes = [], [], []
        while single_class_scores.numel() > 0:
            # keep the highest scoring remaining candidate
            top1_score, top1_class, top1_pred_bbox = single_class_scores[
                0:1], single_class[0:1], single_class_pred_bboxes[0:1]
            single_keep_scores.append(top1_score)
            single_keep_classes.append(top1_class)
            single_keep_pred_bboxes.append(top1_pred_bbox)
            top1_areas = single_class_pred_bboxes_areas[0]
            if single_class_scores.numel() == 1:
                break
            single_class_scores = single_class_scores[1:]
            single_class = single_class[1:]
            single_class_pred_bboxes = single_class_pred_bboxes[1:]
            single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[1:]
            # intersection of the kept box with every remaining box
            overlap_area_top_left = torch.max(
                single_class_pred_bboxes[:, :2], top1_pred_bbox[:, :2])
            overlap_area_bot_right = torch.min(
                single_class_pred_bboxes[:, 2:], top1_pred_bbox[:, 2:])
            overlap_area_sizes = torch.clamp(
                overlap_area_bot_right - overlap_area_top_left, min=0)
            overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
            # compute union_area
            union_area = top1_areas + single_class_pred_bboxes_areas - overlap_area
            union_area = torch.clamp(union_area, min=1e-4)
            # compute ious between the top1 pred_bbox and the other pred_bboxes
            ious = overlap_area / union_area
            # discard every candidate whose IoU with the kept box is too high
            single_class_scores = single_class_scores[ious < self.nms_threshold]
            single_class = single_class[ious < self.nms_threshold]
            single_class_pred_bboxes = single_class_pred_bboxes[
                ious < self.nms_threshold]
            single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                ious < self.nms_threshold]
        single_keep_scores = torch.cat(single_keep_scores, dim=0)
        single_keep_classes = torch.cat(single_keep_classes, dim=0)
        single_keep_pred_bboxes = torch.cat(single_keep_pred_bboxes, dim=0)
        keep_scores.append(single_keep_scores)
        keep_classes.append(single_keep_classes)
        keep_pred_bboxes.append(single_keep_pred_bboxes)
    keep_scores = torch.cat(keep_scores, dim=0)
    keep_classes = torch.cat(keep_classes, dim=0)
    keep_pred_bboxes = torch.cat(keep_pred_bboxes, dim=0)
    return keep_scores, keep_classes, keep_pred_bboxes
With the two pieces above in place, we can now decode. The full decoding flow is: first convert the reg head's tx,ty,tw,th predictions into box coordinate predictions (using the anchor coordinate information), then use a classification score threshold to filter out candidates whose scores are too low; for RetinaNet this threshold is 0.05. Next, apply NMS post-processing to the remaining candidates to obtain the kept detections. Finally, we also set a max_detection_num that caps how many detections appear in the final output. For the COCO dataset this value is 100: no single COCO image has more than 100 annotated objects, and the COCO evaluation protocol itself caps detections at 100 per image.
The decoding code is implemented as follows:
import torch
import torch.nn as nn


class RetinaDecoder(nn.Module):
    def __init__(self,
                 image_w,
                 image_h,
                 min_score_threshold=0.05,
                 nms_threshold=0.5,
                 max_detection_num=100):
        super(RetinaDecoder, self).__init__()
        self.image_w = image_w
        self.image_h = image_h
        self.min_score_threshold = min_score_threshold
        self.nms_threshold = nms_threshold
        self.max_detection_num = max_detection_num

    def forward(self, cls_heads, reg_heads, batch_anchors):
        device = cls_heads[0].device
        # concatenate the per-FPN-level heads and anchors along the anchor dim
        cls_heads = torch.cat(cls_heads, dim=1)
        reg_heads = torch.cat(reg_heads, dim=1)
        batch_anchors = torch.cat(batch_anchors, dim=1)
        batch_scores, batch_classes, batch_pred_bboxes = [], [], []
        for per_image_cls_heads, per_image_reg_heads, per_image_anchors in zip(
                cls_heads, reg_heads, batch_anchors):
            pred_bboxes = self.snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
                per_image_reg_heads, per_image_anchors)
            # best class and its score for every anchor
            scores, score_classes = torch.max(per_image_cls_heads, dim=1)
            # filter out low scoring candidates; the score tensor itself must
            # be filtered last because the mask depends on it
            score_classes = score_classes[
                scores > self.min_score_threshold].float()
            pred_bboxes = pred_bboxes[
                scores > self.min_score_threshold].float()
            scores = scores[scores > self.min_score_threshold].float()
            # pad per-image outputs to max_detection_num with -1
            single_image_scores = (-1) * torch.ones(
                (self.max_detection_num, ), device=device)
            single_image_classes = (-1) * torch.ones(
                (self.max_detection_num, ), device=device)
            single_image_pred_bboxes = (-1) * torch.ones(
                (self.max_detection_num, 4), device=device)
            if scores.shape[0] != 0:
                scores, score_classes, pred_bboxes = self.nms(
                    scores, score_classes, pred_bboxes)
                sorted_keep_scores, sorted_keep_scores_indexes = torch.sort(
                    scores, descending=True)
                sorted_keep_classes = score_classes[
                    sorted_keep_scores_indexes]
                sorted_keep_pred_bboxes = pred_bboxes[
                    sorted_keep_scores_indexes]
                # keep at most max_detection_num detections per image
                final_detection_num = min(self.max_detection_num,
                                          sorted_keep_scores.shape[0])
                single_image_scores[0:final_detection_num] = \
                    sorted_keep_scores[0:final_detection_num]
                single_image_classes[0:final_detection_num] = \
                    sorted_keep_classes[0:final_detection_num]
                single_image_pred_bboxes[0:final_detection_num, :] = \
                    sorted_keep_pred_bboxes[0:final_detection_num, :]
            single_image_scores = single_image_scores.unsqueeze(0)
            single_image_classes = single_image_classes.unsqueeze(0)
            single_image_pred_bboxes = single_image_pred_bboxes.unsqueeze(0)
            batch_scores.append(single_image_scores)
            batch_classes.append(single_image_classes)
            batch_pred_bboxes.append(single_image_pred_bboxes)
        batch_scores = torch.cat(batch_scores, dim=0)
        batch_classes = torch.cat(batch_classes, dim=0)
        batch_pred_bboxes = torch.cat(batch_pred_bboxes, dim=0)
        # batch_scores shape:[batch_size,max_detection_num]
        # batch_classes shape:[batch_size,max_detection_num]
        # batch_pred_bboxes shape:[batch_size,max_detection_num,4]
        return batch_scores, batch_classes, batch_pred_bboxes
    def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
        """
        one_image_scores:[anchor_nums],classification predict scores
        one_image_classes:[anchor_nums],class indexes for predict scores
        one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max
        """
        # sort all candidates by classification score in descending order
        sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(
            one_image_scores, descending=True)
        sorted_one_image_classes = one_image_classes[
            sorted_one_image_scores_indexes]
        sorted_one_image_pred_bboxes = one_image_pred_bboxes[
            sorted_one_image_scores_indexes]
        sorted_pred_bboxes_w_h = (sorted_one_image_pred_bboxes[:, 2:] -
                                  sorted_one_image_pred_bboxes[:, :2])
        sorted_pred_bboxes_areas = (sorted_pred_bboxes_w_h[:, 0] *
                                    sorted_pred_bboxes_w_h[:, 1])
        # record which classes appear among the candidates
        detected_classes = torch.unique(sorted_one_image_classes, sorted=True)
        keep_scores, keep_classes, keep_pred_bboxes = [], [], []
        for detected_class in detected_classes:
            # extract this class's candidates (still sorted by score)
            single_class_mask = sorted_one_image_classes == detected_class
            single_class_scores = sorted_one_image_scores[single_class_mask]
            single_class_pred_bboxes = sorted_one_image_pred_bboxes[
                single_class_mask]
            single_class_pred_bboxes_areas = sorted_pred_bboxes_areas[
                single_class_mask]
            single_class = sorted_one_image_classes[single_class_mask]
            single_keep_scores, single_keep_classes, single_keep_pred_bboxes = [], [], []
            while single_class_scores.numel() > 0:
                # keep the highest scoring remaining candidate
                top1_score, top1_class, top1_pred_bbox = single_class_scores[
                    0:1], single_class[0:1], single_class_pred_bboxes[0:1]
                single_keep_scores.append(top1_score)
                single_keep_classes.append(top1_class)
                single_keep_pred_bboxes.append(top1_pred_bbox)
                top1_areas = single_class_pred_bboxes_areas[0]
                if single_class_scores.numel() == 1:
                    break
                single_class_scores = single_class_scores[1:]
                single_class = single_class[1:]
                single_class_pred_bboxes = single_class_pred_bboxes[1:]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    1:]
                # intersection of the kept box with every remaining box
                overlap_area_top_left = torch.max(
                    single_class_pred_bboxes[:, :2], top1_pred_bbox[:, :2])
                overlap_area_bot_right = torch.min(
                    single_class_pred_bboxes[:, 2:], top1_pred_bbox[:, 2:])
                overlap_area_sizes = torch.clamp(
                    overlap_area_bot_right - overlap_area_top_left, min=0)
                overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]
                # compute union_area
                union_area = top1_areas + single_class_pred_bboxes_areas - overlap_area
                union_area = torch.clamp(union_area, min=1e-4)
                # compute ious between the top1 pred_bbox and the other pred_bboxes
                ious = overlap_area / union_area
                # discard every candidate whose IoU with the kept box is too high
                single_class_scores = single_class_scores[
                    ious < self.nms_threshold]
                single_class = single_class[ious < self.nms_threshold]
                single_class_pred_bboxes = single_class_pred_bboxes[
                    ious < self.nms_threshold]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    ious < self.nms_threshold]
            single_keep_scores = torch.cat(single_keep_scores, dim=0)
            single_keep_classes = torch.cat(single_keep_classes, dim=0)
            single_keep_pred_bboxes = torch.cat(single_keep_pred_bboxes, dim=0)
            keep_scores.append(single_keep_scores)
            keep_classes.append(single_keep_classes)
            keep_pred_bboxes.append(single_keep_pred_bboxes)
        keep_scores = torch.cat(keep_scores, dim=0)
        keep_classes = torch.cat(keep_classes, dim=0)
        keep_pred_bboxes = torch.cat(keep_pred_bboxes, dim=0)
        return keep_scores, keep_classes, keep_pred_bboxes
    def snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
            self, reg_heads, anchors):
        """
        snap reg heads to pred bboxes
        reg_heads:[anchor_nums,4],4:[tx,ty,tw,th]
        anchors:[anchor_nums,4],4:[x_min,y_min,x_max,y_max]
        """
        anchors_wh = anchors[:, 2:] - anchors[:, :2]
        anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh
        device = anchors.device
        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)
        # undo the (0.1,0.1,0.2,0.2) scaling applied during target encoding
        reg_heads = reg_heads * factor
        pred_bboxes_wh = torch.exp(reg_heads[:, 2:]) * anchors_wh
        pred_bboxes_ctr = reg_heads[:, :2] * anchors_wh + anchors_ctr
        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh
        pred_bboxes = torch.cat(
            [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], dim=1)
        pred_bboxes = pred_bboxes.int()
        # clip the predicted boxes to the input image boundary
        pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
        pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
        pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2], max=self.image_w - 1)
        pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3], max=self.image_h - 1)
        # pred bboxes shape:[anchor_nums,4]
        return pred_bboxes
With that, the decode stage is fully implemented.
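Finally, a quick smoke test of the decoder on random inputs (the per-level anchor counts below are made-up small numbers purely for illustration; in the real network they are determined by the FPN feature map sizes):

import torch

decoder = RetinaDecoder(image_w=600, image_h=600)
batch_size, num_classes = 2, 80
level_anchor_nums = [900, 225, 64, 16, 4]  # hypothetical per-level counts
cls_heads = [torch.rand(batch_size, n, num_classes) for n in level_anchor_nums]
reg_heads = [torch.randn(batch_size, n, 4) for n in level_anchor_nums]
batch_anchors = []
for n in level_anchor_nums:
    # build valid anchors with x_min < x_max and y_min < y_max
    x1y1 = torch.rand(batch_size, n, 2) * 300.
    wh = torch.rand(batch_size, n, 2) * 200. + 1.
    batch_anchors.append(torch.cat([x1y1, x1y1 + wh], dim=2))
scores, classes, boxes = decoder(cls_heads, reg_heads, batch_anchors)
print(scores.shape, classes.shape, boxes.shape)
# torch.Size([2, 100]) torch.Size([2, 100]) torch.Size([2, 100, 4])

The -1 padding up to max_detection_num keeps the batch outputs fixed-size, which simplifies downstream evaluation code.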