目标检测3--AnchorFree的FCOS

文章目录

    • 1.介绍
    • 2.FCOS中使用的方法
      • 2.1 网络结构
      • 2.2FCOS中使用`FPN`的多级预测
      • 2.3FCOS中的中心度
    • 3.mmdetection中`FCOS`源码
    • 参考资料


欢迎访问个人网络日志知行空间


1.介绍

论文:《FCOS: Fully Convolutional One-Stage Object Detection》
是澳洲阿德莱德大学的Zhi Tian等最早于2019年04月提交的工作成果,发表在ICCV上。

FCOS是全卷积实现的Anchor Free的一阶目标检测器,避免了训练过程中Anchor相关的计算,减少的训练时的计算量和内存占用,移除了anchor相关的一系列超参数。

Anchor Based方法的缺点:

  • 检测性能对anchorsize/aspect ratio/数量比较敏感。
  • 实际对象的检测框大小分布较广泛,anchor不一定能覆盖
  • 为了得到高召回率,anchor based的方法返回了非常多的anchor box,如FPN中,输入短边为800的图像将总共生成大于180KAnchor Box。超级多的anchor box除了影响性能外,还导致了严重的类别不平衡问题,因为180Kanchor box中有大量的都是不包含对象的negative box

neat fully convolution pixel prediction frameworkanchor box的存在不适用于object detection。先前的预测方法在目标检测框重叠时检测效果不好,如DenseBoxFCOS在距离目标中心较远的位置产生很多低质量的预测边框,这会影响检测的效果,为了克服这个问题,引入了中心度的概念,衡量预测框到距目标框的距离。FCOS添加单层分支,与分类分支并行,以预测"Center-ness"位置。

FCOS的优点:

  • FCN结构,与pixel prediction任务的网络统一,更便于复用语义分割中的tricks
  • anchor free框架,减少模型的设计参数。
  • anchor free框架,减少训练中的IoU计算和box match.
  • FCOS可用于two-stage检测网络中的RPN
  • 易于拓展应用于其他视觉任务,如Instance Segmentation

Anchor Free目标检测器有YoloV1,CornerNet

2.FCOS中使用的方法

2.1 网络结构

将目标检测任务形式化为pixel prediction任务,使用多级预测提升召回率,解决重叠目标的模糊问题。

特征图上的点(x,y)可以重新映射到输入原图上:

( ⌊ s 2 ⌋ + x s , ⌊ s 2 ⌋ + y s ) (\lfloor \frac{s}{2}\rfloor+xs, \lfloor \frac{s}{2}\rfloor+ys) (⌊2s+xs,2s+ys)

FCOS直接将(x, y)对应的检测框位置当作训练样本,具体是指,当(x,y)映射到原图上落入任何ground-truth box中,这个点就当成正样本,否则就是负样本。检测定位回归的变量是 t ∗ = ( l ∗ , t ∗ , r ∗ , b ∗ ) t^*=(l^*, t^*, r^*, b^*) t=(l,t,r,b), t ∗ t^* t分别指中心(x, y)距检测框左/上/右/下四条边的距离。

目标检测3--AnchorFree的FCOS_第1张图片

feature map上的某个点(x,y)同时落入多个bounding box中时,这个点(x,y)被当成ambiguous sampleFPNmulti level机制可以用来解决这个问题。很显然,通过上述介绍可以知道,FCOS中可能有多个feature map上的点(x, y)落入到同个ground truth box中,故可以生成很多positive boxes用于训练。

目标检测3--AnchorFree的FCOS_第2张图片

网络输出

训练了C个二分类器,而非1个多分类器,可以实现多标签预测。feature map后接有4个卷积层的分类分支位置回归分支,回归分支使用了exp(x),将回归预测变量x变换到了(0, +\infinity)上,输出的变量比anchor based方法少9倍。

损失函数

目标检测3--AnchorFree的FCOS_第3张图片

分类使用的是Focal Loss解决类别不平衡问题,回归使用的是IoU损失。

2.2FCOS中使用FPN的多级预测

FCOS中的两个问题:

  • 1) 到最终输出的feature map使用大stridex16会导致低BPR(best possible recall)Anchor Based方法可以对positive box使用低的IoU阈值来补偿这个问题,而对于FCOS,直观的一个猜想是,对于小物体,stride过大时,可能导致feature map上并没有一个点可以与小目标的中心对应,故BPR应该会比较低。
  • 2) 重叠的目标框引入了棘手的模糊性问题,feature map中的一个点(x, y)如何确定应该用来回归重叠框中的哪一个呢?

使用FPN的多级预测解决FCOS中存在的这两个问题。

FPN,使用不同层级中的feature map来预测不同大小的目标。限制不同level feature map回归距离 t ∗ = ( l ∗ , t ∗ , r ∗ , b ∗ ) t^*=(l^*, t^*, r^*, b^*) t=(l,t,r,b)的上限, m i m_i mi表示level ifeature map上每个点回归的最大距离,因此若level ifeature map上某个点回归的距离 t ∗ = ( l ∗ , t ∗ , r ∗ , b ∗ ) t^*=(l^*, t^*, r^*, b^*) t=(l,t,r,b) 大于 m i m_i mi或小于 m i − 1 m_{i-1} mi1时,就将其当作negative box对于 { P 3 , P 4 , P 5 , P 6 , P 7 } \{P_3,P_4,P_5,P_6,P_7\} {P3,P4,P5,P6,P7}对应的 m i m_i mi分别为 0 , 64 , 128 , 256 , 512 , + ∞ 0, 64, 128, 256, 512, +\infty 0,64,128,256,512,+

前面介绍的因预测的是距离,故 t ∗ ∈ ( 0 , ∞ ) t^*\in(0, \infty) t(0,)因此使用了exp函数,现在不同level的回归距离范围是不同的,再使用相同的head就不合理了,增加一个参数 s i s_i si,使用 e x p ( s i x ) exp(s_ix) exp(six)函数,对不同level使用不同的 s i s_i si

2.3FCOS中的中心度

可能有多个feature map中的点对应同个物体检测框,而距离中心较远的feature map点会导致引入很多低质量的物体检测框,影响检测的效果。

FCOS引入了中心度centerness来描述一个点距离目标框中心的远近,以过滤掉偏离中心点的低质量检测框。中心度是通过一个与分类分支并行的单层分支来预测的。对于某个位置的回归目标 ( l ∗ , t ∗ , r ∗ , b ∗ ) (l^*, t^*, r^*, b^*) (l,t,r,b),其中心度的定义为:

c e n t e r n e s s ∗ = m i n ( l ∗ , r ∗ ) m a x ( l ∗ , r ∗ ) × m i n ( t ∗ , b ∗ ) m a x ( t ∗ , b ∗ ) centerness^*=\sqrt{\frac{min(l^*,r^*)}{max(l^*,r^*)}\times \frac{min(t^*,b^*)}{max(t^*,b^*)}} centerness=max(l,r)min(l,r)×max(t,b)min(t,b)

sqrt 运算可以减缓中心度的衰减速度。centerness取值范围0-1,训练使用二分类交叉熵BCE,加到前述的损失函数中一起训练,预测时,将centerness与分类评分score相乘后作为最后检测框的评分,再进行NMS,因此很多偏离中心的框就能被过滤掉了。

3.mmdetection中FCOS源码

FCOS网络结构的定义如上图中二所示,定义的文件在mmdetection工程``文件中。

看一下FCOS Head,从计算loss是使用的self.get_targets方法开始。

min
feat_points
bbox_targets
gt_boxes
center_sample_radius
inside_gt_bbox_mask
regress_ranges
inside_regress_range
areas
min_area_inds
labels
labels
bbox_targets
def _get_target_single(self, gt_bboxes, gt_labels, points, regress_ranges,
                         num_points_per_lvl):
     """Compute regression and classification targets for a single image."""
     num_points = points.size(0)
     num_gts = gt_labels.size(0)
     if num_gts == 0:
          return gt_labels.new_full((num_points,), self.num_classes), \
               gt_bboxes.new_zeros((num_points, 4))

     areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
          gt_bboxes[:, 3] - gt_bboxes[:, 1])
     # TODO: figure out why these two are different
     # areas = areas[None].expand(num_points, num_gts)
     areas = areas[None].repeat(num_points, 1)
     regress_ranges = regress_ranges[:, None, :].expand(
          num_points, num_gts, 2)
     gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
     xs, ys = points[:, 0], points[:, 1]
     xs = xs[:, None].expand(num_points, num_gts)
     ys = ys[:, None].expand(num_points, num_gts)

     left = xs - gt_bboxes[..., 0]
     right = gt_bboxes[..., 2] - xs
     top = ys - gt_bboxes[..., 1]
     bottom = gt_bboxes[..., 3] - ys
     bbox_targets = torch.stack((left, top, right, bottom), -1)

     if self.center_sampling:
          # condition1: inside a `center bbox`
          radius = self.center_sample_radius
          center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2
          center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2
          center_gts = torch.zeros_like(gt_bboxes)
          stride = center_xs.new_zeros(center_xs.shape)

          # project the points on current lvl back to the `original` sizes
          lvl_begin = 0
          for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
               lvl_end = lvl_begin + num_points_lvl
               stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
               lvl_begin = lvl_end

          x_mins = center_xs - stride
          y_mins = center_ys - stride
          x_maxs = center_xs + stride
          y_maxs = center_ys + stride
          center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0],
                                        x_mins, gt_bboxes[..., 0])
          center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1],
                                        y_mins, gt_bboxes[..., 1])
          center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2],
                                        gt_bboxes[..., 2], x_maxs)
          center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3],
                                        gt_bboxes[..., 3], y_maxs)

          cb_dist_left = xs - center_gts[..., 0]
          cb_dist_right = center_gts[..., 2] - xs
          cb_dist_top = ys - center_gts[..., 1]
          cb_dist_bottom = center_gts[..., 3] - ys
          center_bbox = torch.stack(
               (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
          inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
     else:
          # condition1: inside a gt bbox
          inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0

     # condition2: limit the regression range for each location
     max_regress_distance = bbox_targets.max(-1)[0]
     inside_regress_range = (
          (max_regress_distance >= regress_ranges[..., 0])
          & (max_regress_distance <= regress_ranges[..., 1]))

     # if there are still more than one objects for a location,
     # we choose the one with minimal area
     areas[inside_gt_bbox_mask == 0] = INF
     areas[inside_regress_range == 0] = INF
     min_area, min_area_inds = areas.min(dim=1)

     labels = gt_labels[min_area_inds]
     labels[min_area == INF] = self.num_classes  # set as BG
     bbox_targets = bbox_targets[range(num_points), min_area_inds]

     return labels, bbox_targets

中心度的计算:

def centerness_target(self, pos_bbox_targets):
        """Compute centerness targets.
        Args:
            pos_bbox_targets (Tensor): BBox targets of positive bboxes in shape
                (num_pos, 4)
        Returns:
            Tensor: Centerness target.
        """
        # only calculate pos centerness targets, otherwise there may be nan
        left_right = pos_bbox_targets[:, [0, 2]]
        top_bottom = pos_bbox_targets[:, [1, 3]]
        if len(left_right) == 0:
            centerness_targets = left_right[..., 0]
        else:
            centerness_targets = (
                left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
                    top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
        return torch.sqrt(centerness_targets)

Focal Loss的定义:

pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
               (1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
     pred, target, reduction='none') * focal_weight

参考资料

  • 1.https://zhuanlan.zhihu.com/p/63868458
  • 2.https://zhuanlan.zhihu.com/p/32423092

你可能感兴趣的:(计算机视觉)