【mmdetection】FCOS代码阅读二

【mmdetection】FCOS

  • 损失函数

损失函数

下面我们再来看看 FCOS head 里面定义的损失函数 loss。代码中的三引号字符串和行尾注释,是调试时记录下来的各个变量的形状和取值示例。

    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
    def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
        """Compute the classification, box-regression and centerness losses.

        The shapes quoted here were recorded from one traced run (5 feature
        levels, 80 classes); they are examples, not constraints checked by
        this method.

        Args:
            cls_scores: Per-level class scores, e.g.
                [5][batchsize, 80, H_i, W_i].
            bbox_preds: Per-level distance predictions, e.g.
                [5][batchsize, 4, H_i, W_i].
            centernesses: Per-level centerness predictions, e.g.
                [5][batchsize, 1, H_i, W_i].
            gt_bboxes: Per-image ground-truth boxes, e.g.
                [batchsize][num_obj, 4].
            gt_labels: Per-image ground-truth labels, e.g.
                [batchsize][num_obj].
            img_metas: Per-image meta dicts (keys seen in the trace:
                'filename', 'ori_shape', 'img_shape', 'pad_shape',
                'scale_factor', 'flip', 'img_norm_cfg'). Not read here.
            cfg: Training config dict (the trace showed assigner settings,
                'allowed_border', 'pos_weight' and 'debug'). Not read here.
            gt_bboxes_ignore: Not read here.

        Returns:
            dict: ``loss_cls``, ``loss_bbox`` and ``loss_centerness``.
        """
        assert len(cls_scores) == len(bbox_preds) == len(centernesses)

        # Spatial size (H_i, W_i) of every level. Sizes seen in the trace:
        # (100, 152), (50, 76), (25, 38), (13, 19), (7, 10).
        level_sizes = [score_map.size()[-2:] for score_map in cls_scores]

        # A feature map of a given size splits the input image into a grid of
        # that size; each feature-map pixel is represented by the centre of
        # its grid cell, so every level contributes its own set of points.
        # all_level_points: [5][n_points][2]
        reference = bbox_preds[0]  # dtype was torch.float32 in the trace
        all_level_points = self.get_points(level_sizes, reference.dtype,
                                           reference.device)
        # labels:       [5][batch_size * level_points_i]
        # bbox_targets: [5][batch_size * level_points_i, 4]
        labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
                                                gt_labels)

        num_imgs = cls_scores[0].size(0)

        # Move channels last, collapse (batch, H, W) into one axis and join
        # the levels. The trace had 89600 rows in total.
        flatten_cls_scores = torch.cat([
            score_map.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for score_map in cls_scores
        ])  # e.g. [89600, 80]
        flatten_bbox_preds = torch.cat([
            dist_map.permute(0, 2, 3, 1).reshape(-1, 4)
            for dist_map in bbox_preds
        ])  # e.g. [89600, 4]
        flatten_centerness = torch.cat([
            ctr_map.permute(0, 2, 3, 1).reshape(-1)
            for ctr_map in centernesses
        ])  # e.g. [89600]
        flatten_labels = torch.cat(labels)  # e.g. [89600]
        flatten_bbox_targets = torch.cat(bbox_targets)  # e.g. [89600, 4]
        # Tile each level's points once per image so the rows line up with
        # the flattened predictions.
        flatten_points = torch.cat([
            level_points.repeat(num_imgs, 1)
            for level_points in all_level_points
        ])  # e.g. [89600, 2]

        # Non-zero labels mark the positive points.
        pos_inds = flatten_labels.nonzero().reshape(-1)
        num_pos = len(pos_inds)
        # Adding num_imgs keeps avg_factor non-zero when num_pos is 0.
        loss_cls = self.loss_cls(
            flatten_cls_scores,
            flatten_labels,
            avg_factor=num_pos + num_imgs)

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]

        if num_pos == 0:
            # No positive point: use the sums over the empty selections.
            loss_bbox = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()
        else:
            pos_bbox_targets = flatten_bbox_targets[pos_inds]
            pos_centerness_targets = self.centerness_target(pos_bbox_targets)
            pos_points = flatten_points[pos_inds]
            # Predictions and targets are distances from the point; decode
            # both into box coordinates before comparing them.
            pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
            pos_decoded_target_preds = distance2bbox(pos_points,
                                                     pos_bbox_targets)
            # Box loss weighted by the centerness targets.
            loss_bbox = self.loss_bbox(
                pos_decoded_bbox_preds,
                pos_decoded_target_preds,
                weight=pos_centerness_targets,
                avg_factor=pos_centerness_targets.sum())
            loss_centerness = self.loss_centerness(pos_centerness,
                                                   pos_centerness_targets)

        return dict(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness)

来看看get_points的细节

def get_points(self, featmap_sizes, dtype, device):
        """Get points according to feature map sizes.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            dtype (torch.dtype): Type of points.
            device (torch.device): Device of points.

        Returns:
            list: One entry per feature level, each being whatever
                ``get_points_single`` returns for that level's size and
                stride.
        """
        # Level ``lvl`` is paired with ``self.strides[lvl]``. Indexing is kept
        # (instead of zip) so that having fewer strides than levels still
        # raises IndexError rather than silently dropping levels.
        return [
            self.get_points_single(featmap_size, self.strides[lvl], dtype,
                                   device)
            for lvl, featmap_size in enumerate(featmap_sizes)
        ]

    def get_points_single(self, featmap_size, stride, dtype, device):
        """Return the grid-cell centre points of a single feature level.

        Args:
            featmap_size: ``(h, w)`` of the feature map, e.g. (100, 152).
            stride: Spacing between neighbouring points.
            dtype (torch.dtype): Type of points.
            device (torch.device): Device of points.

        Returns:
            Tensor of shape [h * w, 2] holding (x, y) per point, ordered row
            by row (x varies fastest).
        """
        rows, cols = featmap_size
        # Offsets of the grid cells along each axis. With stride 8 and a
        # (100, 152) map the trace showed x = [0., 8., ..., 1208.] and
        # y = [0., 8., ..., 792.].
        xs = torch.arange(0, cols * stride, stride, dtype=dtype, device=device)
        ys = torch.arange(0, rows * stride, stride, dtype=dtype, device=device)
        # Both grids have shape [h, w]: grid_y is constant along a row and
        # grid_x is constant along a column.
        grid_y, grid_x = torch.meshgrid(ys, xs)
        cell_origins = torch.stack(
            (grid_x.reshape(-1), grid_y.reshape(-1)), dim=-1)
        # Shift from the cell origin to the cell centre. Trace for stride 8:
        # [[4., 4.], [12., 4.], [20., 4.], ..., [1212., 796.]].
        half_stride = stride // 2
        return cell_origins + half_stride

fcos_target就是为各level的特征点(也就是原图上每个grid的中心点)生成target。正样本是落在gt box内部、并且到四条边的最大距离满足该层FPN回归范围(regress_ranges)限制的点;如果一个点同时满足多个gt box,则取面积最小的那个。最新版的论文提到了center sampling,即并不是所有落在gt box里的点都是正样本(下面这段代码没有实现center sampling)。

  def fcos_target(self, points, gt_bboxes_list, gt_labels_list):
        """Build per-level labels and box targets for a whole batch.

        Shapes are the ones noted from a traced run with 5 levels.

        Args:
            points: Per-level point tensors, [5][n_points, 2].
            gt_bboxes_list: Per-image boxes, [batch_size][num_objects, 4].
            gt_labels_list: Per-image labels, [batch_size][num_objects].

        Returns:
            tuple: ``(labels, bbox_targets)`` where ``labels`` is
                [5][batch_size * level_points_i] and ``bbox_targets`` is
                [5][batch_size * level_points_i, 4]; within a level the
                images are concatenated in batch order.
        """
        assert len(points) == len(self.regress_ranges)
        num_levels = len(points)

        # Give every point its level's regress range: a 2-element range is
        # broadcast to the [n_points, 2] shape of that level's points.
        ranges_per_level = [
            level_points.new_tensor(
                self.regress_ranges[lvl])[None].expand_as(level_points)
            for lvl, level_points in enumerate(points)
        ]

        # Join all levels so each image is handled in a single call.
        concat_regress_ranges = torch.cat(ranges_per_level, dim=0)
        concat_points = torch.cat(points, dim=0)

        # NOTE(review): multi_apply is assumed to call fcos_target_single once
        # per image and to return one list per output -- confirm against its
        # definition.
        # labels_list:       [batch_size][total_points]
        # bbox_targets_list: [batch_size][total_points, 4]
        labels_list, bbox_targets_list = multi_apply(
            self.fcos_target_single,
            gt_bboxes_list,
            gt_labels_list,
            points=concat_points,
            regress_ranges=concat_regress_ranges)

        # Cut every image's result back into its levels.
        level_lengths = [level_points.size(0) for level_points in points]
        labels_per_img = [
            img_labels.split(level_lengths, 0) for img_labels in labels_list
        ]  # [batch_size][5][level_points_i]
        targets_per_img = [
            img_targets.split(level_lengths, 0)
            for img_targets in bbox_targets_list
        ]  # [batch_size][5][level_points_i, 4]

        # Regroup from image-major to level-major: for each level, join the
        # pieces of all images.
        concat_lvl_labels = [
            torch.cat([img_labels[lvl] for img_labels in labels_per_img])
            for lvl in range(num_levels)
        ]
        concat_lvl_bbox_targets = [
            torch.cat([img_targets[lvl] for img_targets in targets_per_img])
            for lvl in range(num_levels)
        ]
        return concat_lvl_labels, concat_lvl_bbox_targets

    def fcos_target_single(self, gt_bboxes, gt_labels, points, regress_ranges):
        """Assign a label and a distance target to every point of one image.

        Args:
            gt_bboxes: [num_objects, 4] as xmin, ymin, xmax, ymax.
            gt_labels: [num_objects].
            points: [total_points, 2], the points of all levels joined.
            regress_ranges: [total_points, 2], the (low, high) interval that
                the largest of a point's four distances must fall into.

        Returns:
            tuple: ``(labels, bbox_targets)`` with shapes [total_points] and
                [total_points, 4]. Points matched to no box get label 0.
        """
        n_points = points.size(0)
        n_gts = gt_labels.size(0)
        if n_gts == 0:
            # No objects in this image: all labels and targets are zero.
            empty_labels = gt_labels.new_zeros(n_points)
            empty_targets = gt_bboxes.new_zeros((n_points, 4))
            return empty_labels, empty_targets

        box_w = gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1
        box_h = gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1
        # repeat() materialises a real [n_points, n_gts] tensor. The areas are
        # overwritten in place further down, which is presumably why the
        # original author saw different results with an expand() view, whose
        # rows share storage -- NOTE(review): confirm.
        areas = (box_w * box_h)[None].repeat(n_points, 1)

        # Broadcast everything to a (point, box) pairing.
        ranges = regress_ranges[:, None, :].expand(
            n_points, n_gts, 2)  # [n_points, n_gts, 2]
        boxes = gt_bboxes[None].expand(
            n_points, n_gts, 4)  # [n_points, n_gts, 4]
        px = points[:, 0][:, None].expand(n_points, n_gts)
        py = points[:, 1][:, None].expand(n_points, n_gts)

        # Distances from each point to the left, top, right and bottom edge
        # of each box: [n_points, n_gts, 4].
        distances = torch.stack(
            (px - boxes[..., 0], py - boxes[..., 1],
             boxes[..., 2] - px, boxes[..., 3] - py), -1)

        # Condition 1: the point lies strictly inside the box. A smallest
        # distance that is not positive means the point is on or outside an
        # edge.
        inside_gt_bbox_mask = distances.min(-1)[0] > 0

        # Condition 2: the largest distance falls inside the regress range
        # assigned to the point.
        largest_distance = distances.max(-1)[0]
        inside_regress_range = (
            largest_distance >= ranges[..., 0]) & (
                largest_distance <= ranges[..., 1])

        # Pairs failing either condition get an INF area, so that taking the
        # minimum picks the smallest box among the remaining candidates.
        areas[inside_gt_bbox_mask == 0] = INF
        areas[inside_regress_range == 0] = INF
        min_area, min_area_inds = areas.min(dim=1)  # each [n_points]

        labels = gt_labels[min_area_inds]  # [n_points]
        # A minimum of INF means no box qualified for that point.
        labels[min_area == INF] = 0
        bbox_targets = distances[range(n_points), min_area_inds]

        return labels, bbox_targets

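centerness_target就是字面意思:用下式衡量当前位置离物体中心的远近,离中心越远值越接近0,越靠近中心值越接近1。它只对正样本计算,否则可能出现nan。
【mmdetection】FCOS代码阅读二_第1张图片

$$\text{centerness}^{*} = \sqrt{\frac{\min(l^{*}, r^{*})}{\max(l^{*}, r^{*})} \times \frac{\min(t^{*}, b^{*})}{\max(t^{*}, b^{*})}}$$

其中 $l^{*}, t^{*}, r^{*}, b^{*}$ 分别是该位置到gt box左、上、右、下四条边的距离,对应代码中 pos_bbox_targets 的第0、1、2、3列。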

def centerness_target(self, pos_bbox_targets):
        """Compute the centerness target of every positive point.

        Args:
            pos_bbox_targets: [num_pos, 4] distances ordered as
                (left, top, right, bottom).

        Returns:
            Tensor of shape [num_pos]:
            sqrt(min(l, r) / max(l, r) * min(t, b) / max(t, b)).
        """
        # Call this only for positive points; per the original note, other
        # points may produce nan.
        horizontal = pos_bbox_targets[:, [0, 2]]
        vertical = pos_bbox_targets[:, [1, 3]]
        horizontal_ratio = (
            horizontal.min(dim=-1)[0] / horizontal.max(dim=-1)[0])
        vertical_ratio = vertical.min(dim=-1)[0] / vertical.max(dim=-1)[0]
        return torch.sqrt(horizontal_ratio * vertical_ratio)

你可能感兴趣的:(mmdetection)