原文我就不贴了,说一下感受吧!
从检测方法出来,自我感觉一直不协调,现有的方法如fast系列一直比较复杂,强大的神经网络应该是简单高效的,one-stage从yolo出来后感觉好了很多,但是在最后的map上做roi anchor一直感觉特别冗余,还好corner出来了,但其也存在的问题,不同的pool结构,还有总感觉这种方式怪怪的。然后FCOS出来了,完全感受到了高效和简单,在此膜拜一下大神,感觉神经网络就应该这样,以最简单的方式,取得很好的效果。
FCOS: Fully Convolutional One-Stage Object Detection
看到上图了吗,没错,就是这么粗暴,像素级的预测t,l,r,b,当然,上面存在着一个问题,就是一个点包含两个框,怎么搞,当然选小的了,不不不,FPN啊:
稍微解释一下最后的框框,x4是4层卷积,class分类,距离中心的loss,回归,
然后是中心度的定义,二进制交叉熵做loss。
ok,讲完了, 简单吧,粗暴吧,开始实现吧!
更新:
非常好用个点FCOS得配点代码不是,近期研究mmdetection,感觉很好用,特来吧源码解释一下
def loss(self,
cls_scores,
bbox_preds,
centernesses,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore=None):
labels, bbox_targets = self.fcos_target(all_level_points, gt_bboxes,
gt_labels) #生成label
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
for bbox_pred in bbox_preds
]
flatten_centerness = [
centerness.permute(0, 2, 3, 1).reshape(-1)
for centerness in centernesses
]
flatten_cls_scores = torch.cat(flatten_cls_scores)
flatten_bbox_preds = torch.cat(flatten_bbox_preds)
flatten_centerness = torch.cat(flatten_centerness)
flatten_labels = torch.cat(labels)
flatten_bbox_targets = torch.cat(bbox_targets)
loss_cls = self.loss_cls(
flatten_cls_scores, flatten_labels,
avg_factor=num_pos + num_imgs) # cls loss
pos_bbox_preds = flatten_bbox_preds[pos_inds]
pos_bbox_targets = flatten_bbox_targets[pos_inds]
pos_centerness = flatten_centerness[pos_inds]
pos_centerness_targets = self.centerness_target(pos_bbox_targets)
pos_points = flatten_points[pos_inds]
pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
pos_decoded_target_preds = distance2bbox(pos_points,
pos_bbox_targets)
# centerness weighted iou loss
loss_bbox = self.loss_bbox(
pos_decoded_bbox_preds,
pos_decoded_target_preds,
weight=pos_centerness_targets,
avg_factor=pos_centerness_targets.sum()) #llox loss
loss_centerness = self.loss_centerness(pos_centerness,
pos_centerness_targets) #center loss 交叉熵loss
return dict(
loss_cls=loss_cls,
loss_bbox=loss_bbox,
loss_centerness=loss_centerness)
其中生成中心点啊方法,和原文一样:
def centerness_target(self, pos_bbox_targets):
# only calculate pos centerness targets, otherwise there may be nan
left_right = pos_bbox_targets[:, [0, 2]]
top_bottom = pos_bbox_targets[:, [1, 3]]
centerness_targets = (
left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
return torch.sqrt(centerness_targets)
具体loss参考mmdet
focalloss+IOUloss+ CEloss
bbox_head=dict(
type='FCOSHead',
num_classes=81,
in_channels=256,
stacked_convs=4,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='IoULoss', loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)))