mmdetection添加focal loss

最近刚打完一个交通标志检测的比赛,最终排名为18/1073,虽然成绩不好,但是收获了很多,之前打比赛一直使用caffe和detectron框架,这次初步尝试了一下商汤开源的mmdetection,发现模块化的东西很多,用起来也很顺手。由于网上资料不多,所以记录一下一些训练的方法。

1.focal loss

mmdet提供的configs文件里只在retinanet中打开了focal loss的功能,主要是因为一阶段算法使用密集anchor一步回归的方法,其中正负样本非常不均衡,所以focal loss损失函数主要解决了正负样本不均衡以及难分易分样本权值一样的问题(这里与OHEM的区别就是OHEM主要集中在难负样本上,而不考虑易分样本)。

2.解锁所有模型的focal loss功能

这里只激活rpn阶段的focal loss,因为rcnn阶段,rpn已经初步过滤了样本,可以采用OHEM策略。
在mmdet/models/anchor_heads/anchor_head.py的AnchorHead类中第44行use_focal_loss设置为True即可激活所有模型的focal Loss。
然后在config训练文件中的train_cfg的rpn最后加入:

smoothl1_beta=0.11,
gamma=2.0,
alpha=0.25,
allowed_border=-1,
pos_weight=-1,
debug=False

3.mmdet中只使用mask_rcnn来检测

首先在model的dict中删去mask_roi_extractor和mask_head字段及其附属内容,接着在train_cfg的dict中删除所有的mask_size=28,最后在data中把所有的with_mask=True改为with_mask=False即可

你可能感兴趣的:(mmdetection添加focal loss)