mmdetection(1) : FCOS 代码解析

原文我就不贴了,说一下感受吧!

从检测方法出来,自我感觉一直不协调,现有的方法如fast系列一直比较复杂,强大的神经网络应该是简单高效的,one-stage从yolo出来后感觉好了很多,但是在最后的map上做roi anchor一直感觉特别冗余,还好corner出来了,但其也存在的问题,不同的pool结构,还有总感觉这种方式怪怪的。然后FCOS出来了,完全感受到了高效和简单,在此膜拜一下大神,感觉神经网络就应该这样,以最简单的方式,取得很好的效果。

FCOS: Fully Convolutional One-Stage Object Detection

简单介绍一下,方法很简单,好文。
mmdetection(1) : FCOS 代码解析_第1张图片

看到上图了吗,没错,就是这么粗暴,像素级的预测t,l,r,b,当然,上面存在着一个问题,就是一个点包含两个框,怎么搞,当然选小的了,不不不,FPN啊:

mmdetection(1) : FCOS 代码解析_第2张图片

稍微解释一下最后的框框,x4是4层卷积,class分类,距离中心的loss,回归,
mmdetection(1) : FCOS 代码解析_第3张图片
然后是中心度的定义,二进制交叉熵做loss。
在这里插入图片描述

ok,讲完了, 简单吧,粗暴吧,开始实现吧!


更新:
非常好用个点FCOS得配点代码不是,近期研究mmdetection,感觉很好用,特来吧源码解释一下

def loss(self,
             cls_scores,
             bbox_preds,
             centernesses,
             gt_bboxes,
             gt_labels,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):
 
        labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
                                                gt_labels)  #生成label
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
     
        loss_cls = self.loss_cls(
            flatten_cls_scores, flatten_labels,
            avg_factor=num_pos + num_imgs)  # cls loss

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)


        pos_points = flatten_points[pos_inds]
        pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
        pos_decoded_target_preds = distance2bbox(pos_points,
                                                 pos_bbox_targets)
        # centerness weighted iou loss
        loss_bbox = self.loss_bbox(
            pos_decoded_bbox_preds,
            pos_decoded_target_preds,
            weight=pos_centerness_targets,
            avg_factor=pos_centerness_targets.sum())  #llox loss
        loss_centerness = self.loss_centerness(pos_centerness,
                                               pos_centerness_targets)  #center loss 交叉熵loss
      
        return dict(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness)
            

其中生成中心点啊方法,和原文一样:

def centerness_target(self, pos_bbox_targets):
        # only calculate pos centerness targets, otherwise there may be nan
        left_right = pos_bbox_targets[:, [0, 2]]
        top_bottom = pos_bbox_targets[:, [1, 3]]
        centerness_targets = (
            left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
                top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
        return torch.sqrt(centerness_targets)

具体loss参考mmdet
focalloss+IOUloss+ CEloss

bbox_head=dict(
        type='FCOSHead',
        num_classes=81,
        in_channels=256,
        stacked_convs=4,
        feat_channels=256,
        strides=[8, 16, 32, 64, 128],
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='IoULoss', loss_weight=1.0),
        loss_centerness=dict(
            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))

你可能感兴趣的:(检测算法-深度学习)