【庖丁解牛】Implementing RetinaNet from Scratch (Part 5): Regression Prediction Conversion, NMS Post-processing, and Decoding

Table of Contents

  • Regression prediction conversion
  • NMS post-processing
  • Decoding

All of the code has been uploaded to my github repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested under PyTorch 1.4 and confirmed correct.

Regression prediction conversion

Once the model is trained, its outputs must be decoded before testing. A forward pass through the RetinaNet class gives us the cls heads and reg heads, but at this point the reg heads predict tx,ty,tw,th; we need the coordinates of the corresponding anchor boxes to convert these into predicted box coordinates. The conversion rule is simply the inverse of the transform from box coordinates to the regression targets tx,ty,tw,th described in Part 4 of this series.
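
Concretely, writing the anchor center as $(x_a, y_a)$ and the anchor size as $(w_a, h_a)$, the code below first rescales the raw predictions by the factors $(0.1, 0.1, 0.2, 0.2)$ (undoing the scaling applied when the targets were encoded) and then inverts the encoding:

$$x = t_x \cdot w_a + x_a, \quad y = t_y \cdot h_a + y_a, \quad w = w_a \cdot e^{t_w}, \quad h = h_a \cdot e^{t_h}$$

The resulting center/size box $(x, y, w, h)$ is then converted back to corner form $(x_{min}, y_{min}, x_{max}, y_{max})$ and clipped to the image boundary.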

The conversion from regression predictions to box predictions is implemented as follows:

    def snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
            self, reg_heads, anchors):
        """
        snap reg heads to pred bboxes
        reg_heads:[anchor_nums,4],4:[tx,ty,tw,th]
        anchors:[anchor_nums,4],4:[x_min,y_min,x_max,y_max]
        """
        anchors_wh = anchors[:, 2:] - anchors[:, :2]
        anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh

        device = anchors.device
        # undo the scaling applied to tx,ty,tw,th when the targets were encoded
        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)

        reg_heads = reg_heads * factor

        pred_bboxes_wh = torch.exp(reg_heads[:, 2:]) * anchors_wh
        pred_bboxes_ctr = reg_heads[:, :2] * anchors_wh + anchors_ctr

        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh

        pred_bboxes = torch.cat(
            [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], axis=1)
        pred_bboxes = pred_bboxes.int()

        pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
        pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
        pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2],
                                        max=self.image_w - 1)
        pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3],
                                        max=self.image_h - 1)

        # pred bboxes shape:[anchor_nums,4]
        return pred_bboxes
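
A quick sanity check of this transform (a standalone sketch that uses the full RetinaDecoder class defined later in this post; with all-zero regression predictions the decoded box should be the anchor itself):

import torch

decoder = RetinaDecoder(image_w=640, image_h=640)

anchors = torch.tensor([[100., 100., 200., 200.]])  # one 100x100 anchor
reg = torch.zeros((1, 4))                           # tx=ty=tw=th=0

print(decoder.snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(reg, anchors))
# tensor([[100, 100, 200, 200]], dtype=torch.int32)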

NMS post-processing

The standard NMS procedure works as follows. First, sort all candidate detections by classification score in descending order and record which classes appear among them. Then iterate over these detected classes. For each class, extract all candidates belonging to it (since everything was sorted at the start, the extracted candidates are still in score order). Move the first candidate into the keep set, compute the IoU between every remaining candidate and that kept one, and discard every candidate whose IoU exceeds a threshold; for RetinaNet this threshold is 0.5. Then repeat the same process on whatever candidates remain, again moving the first one into the keep set, until no candidates are left, at which point NMS for that class is done. Once every class has been processed, NMS is complete.
In other object detection codebases I have noticed that many implementations do not run NMS per class (i.e., candidates of all classes go through NMS together). I tried this class-agnostic variant as well and found it consistently scores about 0.2 to 0.5 mAP lower than the standard per-class method, so the implementation below sticks with the standard per-class NMS.
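
For reference, torchvision ships both variants, so you can compare them without hand-rolling anything. A minimal sketch, assuming torchvision >= 0.4 is installed (the toy boxes and scores are invented for illustration):

import torch
from torchvision.ops import nms, batched_nms

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [0., 0., 10., 10.]])
scores = torch.tensor([0.9, 0.8, 0.7])
classes = torch.tensor([0, 0, 1])

# class-agnostic: box 2 is suppressed by box 0 even though their classes differ
print(nms(boxes, scores, iou_threshold=0.5))                    # tensor([0])
# per-class (the standard method): the class-1 box survives
print(batched_nms(boxes, scores, classes, iou_threshold=0.5))   # tensor([0, 2])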

The NMS post-processing is implemented as follows:

    def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
        """
        one_image_scores:[anchor_nums],classification predict scores
        one_image_classes:[anchor_nums],class indexes for predict scores
        one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max
        """
        # Sort boxes
        sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(
            one_image_scores, descending=True)
        sorted_one_image_classes = one_image_classes[
            sorted_one_image_scores_indexes]
        sorted_one_image_pred_bboxes = one_image_pred_bboxes[
            sorted_one_image_scores_indexes]
        sorted_pred_bboxes_w_h = (sorted_one_image_pred_bboxes[:, 2:] -
                                  sorted_one_image_pred_bboxes[:, :2])

        sorted_pred_bboxes_areas = (sorted_pred_bboxes_w_h[:, 0] *
                                    sorted_pred_bboxes_w_h[:, 1])
        detected_classes = torch.unique(sorted_one_image_classes, sorted=True)

        keep_scores, keep_classes, keep_pred_bboxes = [], [], []
        for detected_class in detected_classes:
            single_class_scores = sorted_one_image_scores[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes = sorted_one_image_pred_bboxes[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes_areas = sorted_pred_bboxes_areas[
                sorted_one_image_classes == detected_class]
            single_class = sorted_one_image_classes[sorted_one_image_classes ==
                                                    detected_class]

            single_keep_scores, single_keep_classes, single_keep_pred_bboxes = [], [], []
            while single_class_scores.numel() > 0:
                top1_score, top1_class, top1_pred_bbox = single_class_scores[
                    0:1], single_class[0:1], single_class_pred_bboxes[0:1]

                single_keep_scores.append(top1_score)
                single_keep_classes.append(top1_class)
                single_keep_pred_bboxes.append(top1_pred_bbox)

                top1_areas = single_class_pred_bboxes_areas[0]

                if single_class_scores.numel() == 1:
                    break

                single_class_scores = single_class_scores[1:]
                single_class = single_class[1:]
                single_class_pred_bboxes = single_class_pred_bboxes[1:]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    1:]

                overlap_area_top_left = torch.max(
                    single_class_pred_bboxes[:, :2], top1_pred_bbox[:, :2])
                overlap_area_bot_right = torch.min(
                    single_class_pred_bboxes[:, 2:], top1_pred_bbox[:, 2:])
                overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                                 overlap_area_top_left,
                                                 min=0)
                overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]

                # compute union_area
                union_area = top1_areas + single_class_pred_bboxes_areas - overlap_area
                union_area = torch.clamp(union_area, min=1e-4)
                # compute ious for top1 pred_bbox and the other pred_bboxes
                ious = overlap_area / union_area

                single_class_scores = single_class_scores[
                    ious < self.nms_threshold]
                single_class = single_class[ious < self.nms_threshold]
                single_class_pred_bboxes = single_class_pred_bboxes[
                    ious < self.nms_threshold]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    ious < self.nms_threshold]

            single_keep_scores = torch.cat(single_keep_scores, axis=0)
            single_keep_classes = torch.cat(single_keep_classes, axis=0)
            single_keep_pred_bboxes = torch.cat(single_keep_pred_bboxes,
                                                axis=0)

            keep_scores.append(single_keep_scores)
            keep_classes.append(single_keep_classes)
            keep_pred_bboxes.append(single_keep_pred_bboxes)

        keep_scores = torch.cat(keep_scores, axis=0)
        keep_classes = torch.cat(keep_classes, axis=0)
        keep_pred_bboxes = torch.cat(keep_pred_bboxes, axis=0)

        return keep_scores, keep_classes, keep_pred_bboxes
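
A quick toy check of this method, using the same boxes as the torchvision sketch above (this relies on the full RetinaDecoder class defined in the next section, which hosts nms as a method):

import torch

decoder = RetinaDecoder(image_w=640, image_h=640)

scores = torch.tensor([0.9, 0.8, 0.7])
classes = torch.tensor([0., 0., 1.])
boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [0., 0., 10., 10.]])

keep_scores, keep_classes, keep_boxes = decoder.nms(scores, classes, boxes)
# box 1 overlaps box 0 with IoU ~0.68 > 0.5 and shares its class, so it is dropped;
# box 2 overlaps box 0 completely but belongs to another class, so it survives
print(keep_scores)   # tensor([0.9000, 0.7000])
print(keep_classes)  # tensor([0., 1.])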

Decoding

With the two pieces above in place, we can now implement decoding. The full decoding pipeline is: first convert the reg heads' tx,ty,tw,th predictions into box coordinate predictions (this requires the anchor coordinates); then use a classification score threshold to filter out candidates whose scores are too low, which for RetinaNet is 0.05; then run NMS post-processing on the remaining candidates to obtain the kept detections. Finally, we also set a max_detection_num that caps how many detections appear in the final output. For the COCO dataset this value is 100, since no single COCO image is annotated with more than 100 objects.
The decoding is implemented as follows:

class RetinaDecoder(nn.Module):
    def __init__(self,
                 image_w,
                 image_h,
                 min_score_threshold=0.05,
                 nms_threshold=0.5,
                 max_detection_num=100):
        super(RetinaDecoder, self).__init__()
        self.image_w = image_w
        self.image_h = image_h
        self.min_score_threshold = min_score_threshold
        self.nms_threshold = nms_threshold
        self.max_detection_num = max_detection_num

    def forward(self, cls_heads, reg_heads, batch_anchors):
        device = cls_heads[0].device
        cls_heads = torch.cat(cls_heads, axis=1)
        reg_heads = torch.cat(reg_heads, axis=1)
        batch_anchors = torch.cat(batch_anchors, axis=1)

        batch_scores, batch_classes, batch_pred_bboxes = [], [], []
        for per_image_cls_heads, per_image_reg_heads, per_image_anchors in zip(
                cls_heads, reg_heads, batch_anchors):
            pred_bboxes = self.snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
                per_image_reg_heads, per_image_anchors)
            scores, score_classes = torch.max(per_image_cls_heads, dim=1)
            score_mask = scores > self.min_score_threshold
            score_classes = score_classes[score_mask].float()
            pred_bboxes = pred_bboxes[score_mask].float()
            scores = scores[score_mask].float()

            # pad per-image outputs to a fixed size, filling empty slots with -1
            single_image_scores = (-1) * torch.ones(
                (self.max_detection_num, ), device=device)
            single_image_classes = (-1) * torch.ones(
                (self.max_detection_num, ), device=device)
            single_image_pred_bboxes = (-1) * torch.ones(
                (self.max_detection_num, 4), device=device)

            if scores.shape[0] != 0:
                scores, score_classes, pred_bboxes = self.nms(
                    scores, score_classes, pred_bboxes)

                sorted_keep_scores, sorted_keep_scores_indexes = torch.sort(
                    scores, descending=True)
                sorted_keep_classes = score_classes[sorted_keep_scores_indexes]
                sorted_keep_pred_bboxes = pred_bboxes[
                    sorted_keep_scores_indexes]

                final_detection_num = min(self.max_detection_num,
                                          sorted_keep_scores.shape[0])

                single_image_scores[
                    0:final_detection_num] = sorted_keep_scores[
                        0:final_detection_num]
                single_image_classes[
                    0:final_detection_num] = sorted_keep_classes[
                        0:final_detection_num]
                single_image_pred_bboxes[
                    0:final_detection_num, :] = sorted_keep_pred_bboxes[
                        0:final_detection_num, :]

            single_image_scores = single_image_scores.unsqueeze(0)
            single_image_classes = single_image_classes.unsqueeze(0)
            single_image_pred_bboxes = single_image_pred_bboxes.unsqueeze(0)

            batch_scores.append(single_image_scores)
            batch_classes.append(single_image_classes)
            batch_pred_bboxes.append(single_image_pred_bboxes)

        batch_scores = torch.cat(batch_scores, axis=0)
        batch_classes = torch.cat(batch_classes, axis=0)
        batch_pred_bboxes = torch.cat(batch_pred_bboxes, axis=0)

        # batch_scores shape:[batch_size,max_detection_num]
        # batch_classes shape:[batch_size,max_detection_num]
        # batch_pred_bboxes shape[batch_size,max_detection_num,4]
        return batch_scores, batch_classes, batch_pred_bboxes

    def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
        """
        one_image_scores:[anchor_nums],classification predict scores
        one_image_classes:[anchor_nums],class indexes for predict scores
        one_image_pred_bboxes:[anchor_nums,4],4:x_min,y_min,x_max,y_max
        """
        # Sort boxes
        sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(
            one_image_scores, descending=True)
        sorted_one_image_classes = one_image_classes[
            sorted_one_image_scores_indexes]
        sorted_one_image_pred_bboxes = one_image_pred_bboxes[
            sorted_one_image_scores_indexes]
        sorted_pred_bboxes_w_h = (sorted_one_image_pred_bboxes[:, 2:] -
                                  sorted_one_image_pred_bboxes[:, :2])

        sorted_pred_bboxes_areas = (sorted_pred_bboxes_w_h[:, 0] *
                                    sorted_pred_bboxes_w_h[:, 1])
        detected_classes = torch.unique(sorted_one_image_classes, sorted=True)

        keep_scores, keep_classes, keep_pred_bboxes = [], [], []
        for detected_class in detected_classes:
            single_class_scores = sorted_one_image_scores[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes = sorted_one_image_pred_bboxes[
                sorted_one_image_classes == detected_class]
            single_class_pred_bboxes_areas = sorted_pred_bboxes_areas[
                sorted_one_image_classes == detected_class]
            single_class = sorted_one_image_classes[sorted_one_image_classes ==
                                                    detected_class]

            single_keep_scores, single_keep_classes, single_keep_pred_bboxes = [], [], []
            while single_class_scores.numel() > 0:
                top1_score, top1_class, top1_pred_bbox = single_class_scores[
                    0:1], single_class[0:1], single_class_pred_bboxes[0:1]

                single_keep_scores.append(top1_score)
                single_keep_classes.append(top1_class)
                single_keep_pred_bboxes.append(top1_pred_bbox)

                top1_areas = single_class_pred_bboxes_areas[0]

                if single_class_scores.numel() == 1:
                    break

                single_class_scores = single_class_scores[1:]
                single_class = single_class[1:]
                single_class_pred_bboxes = single_class_pred_bboxes[1:]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    1:]

                overlap_area_top_left = torch.max(
                    single_class_pred_bboxes[:, :2], top1_pred_bbox[:, :2])
                overlap_area_bot_right = torch.min(
                    single_class_pred_bboxes[:, 2:], top1_pred_bbox[:, 2:])
                overlap_area_sizes = torch.clamp(overlap_area_bot_right -
                                                 overlap_area_top_left,
                                                 min=0)
                overlap_area = overlap_area_sizes[:, 0] * overlap_area_sizes[:, 1]

                # compute union_area
                union_area = top1_areas + single_class_pred_bboxes_areas - overlap_area
                union_area = torch.clamp(union_area, min=1e-4)
                # compute ious for top1 pred_bbox and the other pred_bboxes
                ious = overlap_area / union_area

                single_class_scores = single_class_scores[
                    ious < self.nms_threshold]
                single_class = single_class[ious < self.nms_threshold]
                single_class_pred_bboxes = single_class_pred_bboxes[
                    ious < self.nms_threshold]
                single_class_pred_bboxes_areas = single_class_pred_bboxes_areas[
                    ious < self.nms_threshold]

            single_keep_scores = torch.cat(single_keep_scores, axis=0)
            single_keep_classes = torch.cat(single_keep_classes, axis=0)
            single_keep_pred_bboxes = torch.cat(single_keep_pred_bboxes,
                                                axis=0)

            keep_scores.append(single_keep_scores)
            keep_classes.append(single_keep_classes)
            keep_pred_bboxes.append(single_keep_pred_bboxes)

        keep_scores = torch.cat(keep_scores, axis=0)
        keep_classes = torch.cat(keep_classes, axis=0)
        keep_pred_bboxes = torch.cat(keep_pred_bboxes, axis=0)

        return keep_scores, keep_classes, keep_pred_bboxes

    def snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
            self, reg_heads, anchors):
        """
        snap reg heads to pred bboxes
        reg_heads:[anchor_nums,4],4:[tx,ty,tw,th]
        anchors:[anchor_nums,4],4:[x_min,y_min,x_max,y_max]
        """
        anchors_wh = anchors[:, 2:] - anchors[:, :2]
        anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh

        device = anchors.device
        # undo the scaling applied to tx,ty,tw,th when the targets were encoded
        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)

        reg_heads = reg_heads * factor

        pred_bboxes_wh = torch.exp(reg_heads[:, 2:]) * anchors_wh
        pred_bboxes_ctr = reg_heads[:, :2] * anchors_wh + anchors_ctr

        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh

        pred_bboxes = torch.cat(
            [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], axis=1)
        pred_bboxes = pred_bboxes.int()

        pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
        pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
        pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2],
                                        max=self.image_w - 1)
        pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3],
                                        max=self.image_h - 1)

        # pred bboxes shape:[anchor_nums,4]
        return pred_bboxes

With that, the decoding part is complete.
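
Finally, a sketch of how the decoder is wired up at test time. Here model and images are placeholders, the 640x640 input resolution is an assumption, and the model is assumed to return the per-level cls heads, reg heads, and batch anchors as lists of tensors, as set up in the earlier posts of this series:

import torch

# hypothetical test-time wiring; `model` and `images` are placeholders
decoder = RetinaDecoder(image_w=640, image_h=640)

with torch.no_grad():
    cls_heads, reg_heads, batch_anchors = model(images)
    batch_scores, batch_classes, batch_pred_bboxes = decoder(
        cls_heads, reg_heads, batch_anchors)

# batch_scores:      [batch_size, 100]
# batch_classes:     [batch_size, 100]
# batch_pred_bboxes: [batch_size, 100, 4]
# slots without a detection are filled with -1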
