【Object Detection】Speeding up post-processing! How several different NMS post-processing methods compare in speed

Table of Contents

  • Fast NMS
  • torchvision.ops.nms()
  • Speed and accuracy of several NMS post-processing methods

All code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All code below has been tested on PyTorch 1.4 and verified to be correct.

Besides the NMS post-processing method used earlier in "Implementing RetinaNet from Scratch (Final)", I have tried two more NMS post-processing methods. One is Fast NMS, proposed in YOLACT (https://arxiv.org/pdf/1904.02689.pdf); the other is torchvision's built-in torchvision.ops.nms(). Note that the ops in torchvision.ops do not support TorchScript.

Fast NMS

Fast NMS replaces the iterative computation of traditional NMS with a single matrix computation that produces the result in one pass. Traditional NMS first sorts all boxes by classification score in descending order and then iterates: each iteration keeps the box with the highest remaining score, computes the IoU between every other box and that box, and deletes the boxes whose IoU exceeds the threshold, repeating until no candidate boxes remain.
Fast NMS also sorts all boxes by classification score in descending order, then computes the pairwise IoU between all boxes, yielding a symmetric matrix. The matrix is made upper triangular, with the main diagonal (each box's IoU with itself) also zeroed out. Taking the maximum IoU along dimension 0 then gives, for each box, its largest IoU with any higher-scoring box; every box whose maximum exceeds the filtering threshold is discarded. In effect, starting from the highest-scoring box, each box is matched against the higher-scoring box it overlaps most, and is filtered out if that IoU exceeds the threshold. Because the matrix is upper triangular, a box can only be suppressed by boxes ranked before it, so filtering a later box never disturbs the boxes ahead of it. Note that, unlike traditional NMS, a box that has itself been suppressed can still suppress boxes after it, which makes Fast NMS slightly more aggressive.
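
To make the matrix trick concrete, here is a tiny standalone example (the 3×3 IoU values are made up for illustration):

import torch

# Made-up pairwise IoU matrix for 3 boxes already sorted by descending score
ious = torch.tensor([[1.0, 0.8, 0.1],
                     [0.8, 1.0, 0.2],
                     [0.1, 0.2, 1.0]])

# Keep only the part above the diagonal: column j now holds box j's IoU
# with every higher-scoring box i < j
ious = torch.triu(ious, diagonal=1)

# For each box, its largest IoU with any higher-scoring box
max_ious = ious.max(dim=0)[0]  # tensor([0.0000, 0.8000, 0.2000])

# Boxes whose max IoU stays below the threshold survive
keep = max_ious < 0.5  # tensor([True, False, True]): box 1 is suppressed by box 0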

The Fast NMS implementation is as follows:

# Replace the NMS in the original RetinaDecoder with these two functions.
    def nms(self, one_image_scores, one_image_classes, one_image_pred_bboxes):
        """
        one_image_scores:[anchor_nums], classification predict scores
        one_image_classes:[anchor_nums], class indexes for predict scores
        one_image_pred_bboxes:[anchor_nums,4], 4:x_min,y_min,x_max,y_max
        """
        device = one_image_scores.device
        final_scores = (-1) * torch.ones(
            (self.max_detection_num, ), device=device)
        final_classes = (-1) * torch.ones(
            (self.max_detection_num, ), device=device)
        final_pred_bboxes = (-1) * torch.ones(
            (self.max_detection_num, 4), device=device)

        if one_image_scores.shape[0] == 0:
            return final_scores, final_classes, final_pred_bboxes

        # Sort boxes
        sorted_one_image_scores, sorted_one_image_scores_indexes = torch.sort(
            one_image_scores, descending=True)
        sorted_one_image_classes = one_image_classes[
            sorted_one_image_scores_indexes]
        sorted_one_image_pred_bboxes = one_image_pred_bboxes[
            sorted_one_image_scores_indexes]

        ious = self.box_iou(sorted_one_image_pred_bboxes,
                            sorted_one_image_pred_bboxes)

        # keep only the upper triangle (above the diagonal) of the IoU matrix
        ious = torch.triu(ious, diagonal=1)

        # for each box, the max IoU with any higher-scoring box;
        # keep boxes whose max IoU is below the threshold
        keep = ious.max(dim=0)[0]
        keep = keep < self.nms_threshold

        keep_scores = sorted_one_image_scores[keep]
        keep_classes = sorted_one_image_classes[keep]
        keep_pred_bboxes = sorted_one_image_pred_bboxes[keep]

        final_detection_num = min(self.max_detection_num, keep_scores.shape[0])

        final_scores[0:final_detection_num] = keep_scores[
            0:final_detection_num]
        final_classes[0:final_detection_num] = keep_classes[
            0:final_detection_num]
        final_pred_bboxes[0:final_detection_num, :] = keep_pred_bboxes[
            0:final_detection_num, :]

        return final_scores, final_classes, final_pred_bboxes

    def box_iou(self, boxes1, boxes2):
        """
        boxes1:[N, 4]
        boxes2:[M, 4]
        ious:[N, M]
        """
        area1 = (boxes1.t()[2] - boxes1.t()[0]) * (boxes1.t()[3] -
                                                   boxes1.t()[1])
        area2 = (boxes2.t()[2] - boxes2.t()[0]) * (boxes2.t()[3] -
                                                   boxes2.t()[1])

        # boxes1[:, None, :2] shape:[N, 1, 2], boxes2[:, :2] shape:[M, 2],
        # broadcasting gives [N, M, 2]
        overlap_area_left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2])
        overlap_area_right_bot = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])

        overlap_area_sizes = (overlap_area_right_bot -
                              overlap_area_left_top).clamp(min=0)
        overlap_area = overlap_area_sizes[:, :, 0] * overlap_area_sizes[:, :,
                                                                        1]
        ious = overlap_area / (area1[:, None] + area2 - overlap_area)

        return ious
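
Since boxes1 and boxes2 are the same sorted box set here, ious is the symmetric [N, N] matrix described above. As a quick standalone sanity check (a sketch, not part of the original code; torchvision.ops.box_iou stands in for the box_iou method above), we can compare the Fast NMS keep mask against torchvision's traditional NMS on random boxes:

import torch
from torchvision.ops import box_iou, nms as tv_nms

torch.manual_seed(0)
n = 200
xy = torch.rand(n, 2) * 400
wh = torch.rand(n, 2) * 80 + 1.0
boxes = torch.cat([xy, xy + wh], dim=1)  # [n, 4], x_min,y_min,x_max,y_max
scores, order = torch.rand(n).sort(descending=True)
boxes = boxes[order]  # boxes sorted by descending score

# Fast NMS keep mask, same matrix trick as in the method above
ious = torch.triu(box_iou(boxes, boxes), diagonal=1)
fast_keep = ious.max(dim=0)[0] < 0.5

# traditional NMS keep indices from torchvision, for comparison
tv_keep = tv_nms(boxes, scores, 0.5)

# Fast NMS typically keeps the same boxes or slightly fewer
print(fast_keep.sum().item(), tv_keep.numel())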

torchvision.ops.nms()

The NMS procedure in torchvision.ops.nms is exactly the same as the NMS post-processing method I described in "Implementing RetinaNet from Scratch (Final)", but since it is implemented in C++, it runs faster than my implementation.
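
Its interface is simple: pass boxes in (x_min, y_min, x_max, y_max) format, their scores, and an IoU threshold, and it returns the indices of the kept boxes sorted by descending score. A minimal sketch with made-up boxes:

import torch
from torchvision.ops import nms

# two heavily overlapping boxes plus one separate box (made-up values)
boxes = torch.tensor([[10., 10., 50., 50.],
                      [12., 12., 52., 52.],
                      [100., 100., 150., 150.]])
scores = torch.tensor([0.9, 0.8, 0.7])

keep = nms(boxes, scores, iou_threshold=0.5)
print(keep)  # tensor([0, 2]): the lower-scoring overlapping box is suppressed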

The RetinaDecoder implementation using torchvision.ops.nms is as follows:

import torch
import torch.nn as nn
from torchvision.ops import nms


class RetinaDecoder(nn.Module):
    def __init__(self,
                 image_w,
                 image_h,
                 top_n=1000,
                 min_score_threshold=0.05,
                 nms_threshold=0.5,
                 max_detection_num=100):
        super(RetinaDecoder, self).__init__()
        self.image_w = image_w
        self.image_h = image_h
        self.top_n = top_n
        self.min_score_threshold = min_score_threshold
        self.nms_threshold = nms_threshold
        self.max_detection_num = max_detection_num

    def forward(self, cls_heads, reg_heads, batch_anchors):
        device = cls_heads[0].device
        with torch.no_grad():
            filter_scores, filter_score_classes = [], []
            filter_reg_heads, filter_batch_anchors = [], []
            for per_level_cls_head, per_level_reg_head, per_level_anchor in zip(
                    cls_heads, reg_heads, batch_anchors):
                scores, score_classes = torch.max(per_level_cls_head, dim=2)
                if scores.shape[1] >= self.top_n:
                    scores, indexes = torch.topk(scores,
                                                 self.top_n,
                                                 dim=1,
                                                 largest=True,
                                                 sorted=True)
                    score_classes = torch.gather(score_classes, 1, indexes)
                    per_level_reg_head = torch.gather(
                        per_level_reg_head, 1,
                        indexes.unsqueeze(-1).repeat(1, 1, 4))
                    per_level_anchor = torch.gather(
                        per_level_anchor, 1,
                        indexes.unsqueeze(-1).repeat(1, 1, 4))

                filter_scores.append(scores)
                filter_score_classes.append(score_classes)
                filter_reg_heads.append(per_level_reg_head)
                filter_batch_anchors.append(per_level_anchor)

            filter_scores = torch.cat(filter_scores, dim=1)
            filter_score_classes = torch.cat(filter_score_classes, dim=1)
            filter_reg_heads = torch.cat(filter_reg_heads, dim=1)
            filter_batch_anchors = torch.cat(filter_batch_anchors, dim=1)

            batch_scores, batch_classes, batch_pred_bboxes = [], [], []
            for per_image_scores, per_image_score_classes, per_image_reg_heads, per_image_anchors in zip(
                    filter_scores, filter_score_classes, filter_reg_heads,
                    filter_batch_anchors):
                pred_bboxes = self.snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
                    per_image_reg_heads, per_image_anchors)
                score_classes = per_image_score_classes[
                    per_image_scores > self.min_score_threshold].float()
                pred_bboxes = pred_bboxes[
                    per_image_scores > self.min_score_threshold].float()
                scores = per_image_scores[
                    per_image_scores > self.min_score_threshold].float()

                one_image_scores = (-1) * torch.ones(
                    (self.max_detection_num, ), device=device)
                one_image_classes = (-1) * torch.ones(
                    (self.max_detection_num, ), device=device)
                one_image_pred_bboxes = (-1) * torch.ones(
                    (self.max_detection_num, 4), device=device)

                if scores.shape[0] != 0:
                    # Sort boxes
                    sorted_scores, sorted_indexes = torch.sort(scores,
                                                               descending=True)
                    sorted_score_classes = score_classes[sorted_indexes]
                    sorted_pred_bboxes = pred_bboxes[sorted_indexes]

                    keep = nms(sorted_pred_bboxes, sorted_scores,
                               self.nms_threshold)
                    keep_scores = sorted_scores[keep]
                    keep_classes = sorted_score_classes[keep]
                    keep_pred_bboxes = sorted_pred_bboxes[keep]

                    final_detection_num = min(self.max_detection_num,
                                              keep_scores.shape[0])

                    one_image_scores[0:final_detection_num] = keep_scores[
                        0:final_detection_num]
                    one_image_classes[0:final_detection_num] = keep_classes[
                        0:final_detection_num]
                    one_image_pred_bboxes[
                        0:final_detection_num, :] = keep_pred_bboxes[
                            0:final_detection_num, :]

                one_image_scores = one_image_scores.unsqueeze(0)
                one_image_classes = one_image_classes.unsqueeze(0)
                one_image_pred_bboxes = one_image_pred_bboxes.unsqueeze(0)

                batch_scores.append(one_image_scores)
                batch_classes.append(one_image_classes)
                batch_pred_bboxes.append(one_image_pred_bboxes)

            batch_scores = torch.cat(batch_scores, dim=0)
            batch_classes = torch.cat(batch_classes, dim=0)
            batch_pred_bboxes = torch.cat(batch_pred_bboxes, dim=0)

            # batch_scores shape:[batch_size,max_detection_num]
            # batch_classes shape:[batch_size,max_detection_num]
            # batch_pred_bboxes shape[batch_size,max_detection_num,4]
            return batch_scores, batch_classes, batch_pred_bboxes

    def snap_tx_ty_tw_th_reg_heads_to_x1_y1_x2_y2_bboxes(
            self, reg_heads, anchors):
        """
        snap reg heads to pred bboxes
        reg_heads:[anchor_nums,4],4:[tx,ty,tw,th]
        anchors:[anchor_nums,4],4:[x_min,y_min,x_max,y_max]
        """
        anchors_wh = anchors[:, 2:] - anchors[:, :2]
        anchors_ctr = anchors[:, :2] + 0.5 * anchors_wh

        device = anchors.device
        # scale factors for the regression outputs
        factor = torch.tensor([[0.1, 0.1, 0.2, 0.2]]).to(device)

        reg_heads = reg_heads * factor

        pred_bboxes_wh = torch.exp(reg_heads[:, 2:]) * anchors_wh
        pred_bboxes_ctr = reg_heads[:, :2] * anchors_wh + anchors_ctr

        pred_bboxes_x_min_y_min = pred_bboxes_ctr - 0.5 * pred_bboxes_wh
        pred_bboxes_x_max_y_max = pred_bboxes_ctr + 0.5 * pred_bboxes_wh

        pred_bboxes = torch.cat(
            [pred_bboxes_x_min_y_min, pred_bboxes_x_max_y_max], dim=1)
        pred_bboxes = pred_bboxes.int()

        pred_bboxes[:, 0] = torch.clamp(pred_bboxes[:, 0], min=0)
        pred_bboxes[:, 1] = torch.clamp(pred_bboxes[:, 1], min=0)
        pred_bboxes[:, 2] = torch.clamp(pred_bboxes[:, 2],
                                        max=self.image_w - 1)
        pred_bboxes[:, 3] = torch.clamp(pred_bboxes[:, 3],
                                        max=self.image_h - 1)

        # pred bboxes shape:[anchor_nums,4]
        return pred_bboxes
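
A minimal usage sketch with dummy inputs (everything below is a placeholder for illustration: the per-level anchor counts are made up, and the random tensors merely follow the [batch, anchor_num, ...] layout that forward expects):

import torch

decoder = RetinaDecoder(image_w=667, image_h=667)

batch_size, num_classes = 1, 80
level_anchor_nums = [100, 50, 25]  # one entry per FPN level, made-up counts

cls_heads = [torch.rand(batch_size, n, num_classes) for n in level_anchor_nums]
reg_heads = [torch.randn(batch_size, n, 4) for n in level_anchor_nums]
batch_anchors = []
for n in level_anchor_nums:
    xy = torch.rand(batch_size, n, 2) * 300
    wh = torch.rand(batch_size, n, 2) * 100 + 1.0
    batch_anchors.append(torch.cat([xy, xy + wh], dim=2))  # x_min,y_min,x_max,y_max

scores, classes, bboxes = decoder(cls_heads, reg_heads, batch_anchors)
print(scores.shape, classes.shape, bboxes.shape)
# torch.Size([1, 100]) torch.Size([1, 100]) torch.Size([1, 100, 4])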

Speed and accuracy of several NMS post-processing methods

Every row in the table below uses the same model: ResNet50-RetinaNet-myresize667-fastdecode from "Implementing RetinaNet from Scratch (Final)". "fastdecode" is the post-processing method described in that article. Tests use batch=1 and resize=667; speed is the total time to run all images in COCO2017_val divided by the number of images, in ms per image.
All tests were run on a single GTX 1070 Max-Q.

| Network                        | NMS method            | epoch12 mAP | speed (ms) |
| ------------------------------ | --------------------- | ----------- | ---------- |
| ResNet50-RetinaNet-myresize667 | fastdecode            | 0.293       | 154        |
| ResNet50-RetinaNet-myresize667 | fast nms              | 0.282       | 128        |
| ResNet50-RetinaNet-myresize667 | torchvision.ops.nms() | 0.293       | 118        |
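
For reference, here is a hedged sketch of how per-image latency like the above can be measured (model, decoder, and val_loader are placeholders, not the original benchmark script; the model is assumed to return the three per-level outputs the decoder expects):

import time

import torch

model.eval()  # model, decoder and val_loader are assumed to exist; batch size is 1
total_time, num_images = 0.0, 0
with torch.no_grad():
    for images, _ in val_loader:
        images = images.cuda()
        torch.cuda.synchronize()  # finish pending GPU work before starting the timer
        start = time.time()
        cls_heads, reg_heads, batch_anchors = model(images)
        scores, classes, bboxes = decoder(cls_heads, reg_heads, batch_anchors)
        torch.cuda.synchronize()  # wait for the decode to finish before stopping the timer
        total_time += time.time() - start
        num_images += images.shape[0]

print(f"speed: {total_time / num_images * 1000:.1f} ms per image")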

As the table shows, torchvision.ops.nms() is the fastest, and its accuracy is exactly the same as my fastdecode implementation. Fast NMS is faster than my fastdecode, but since YOLACT applies it to instance segmentation, using it for detection causes a relatively large drop in mAP (0.293 → 0.282 here).
