目标检测之focal loss

  Focal Loss for Dense Object DetectionTsung-Yi Lin Priya Goyal Ross Girshick Kaiming He Piotr Dollar´Facebook AI Research (FAIR),最近出的一篇做目标检测的工作,本质上是个利用loss function解决样本imbalance问题的工作。

Motivation

      作者说是为了探究为什么ssd、yolo之类的one-stage目标检测网络不如rcnn系列的two-stage检测网络框架效果好,提出了原因在于one-stage训练的过程中没有two-stage的RPN过程来做cascade,所以会有大量的简单负样本在分类器的训练过程中导致训练模型结果不好。采用online hard sample mining(OHEM)的方式是比较通用的方法来解决imbalance的问题,作者设计了一个叫做focal loss的loss function来解决同样的问题,可以认为是OHEM的一个替代方案。

基本原理

       先上个网络结构图,目前主要的one-stage检测框架如下,一个FPN网络带不同scale和shape的anchor,然后不同scale的feature map的组合接出两个子网络,一个用来回归bbox的位置,一个用来回归这个位置所对应的物体分类结果用于判定是否是一个误检。上面提到的样本不均衡的问题就是在这个class subnet上出现的,这会导致这个子网络的预测结果不够好,使得整体检测结果下降。 目标检测之focal loss_第1张图片
        所以focal loss就是用来代替这个class subnet的损失函数的,focal loss是交叉熵损失函数(cross entropy,公式2)的变种,
        
        传统的应对样本不均衡的办法就是在这个loss上面加个参数a,通常是样本比例取反,这样来保证数量少的样本获得较大的weight来抵消数量少的影响。focal loss的定义是:
         可以看到在原有的基础上增加了一个预测概率p和超参数r,其中p的存在就是如果这个样本预测的已经很好了(也就是p->1)那么这个样本产生的loss就接近于0,r的作用是对这个接近的速度做控制,把loss画出来如下图:
           目标检测之focal loss_第2张图片
            可以看出r越大,预测越正确的样本loss下降的越快,也就是这部分样本对于loss的贡献就越小。其实这个工作原理到此就差不多了,作者起了个挺好听的名字叫RetinaNet,下面是一些实现细节:
           1. class subnet和regression subnet用了更深的3*3 conv层
           2. focal loss初始化的时候加了bias: b = log((1 π)),其中π = 0.01,用来防止最开始的几轮迭代时梯度不稳定
   3.8块GPU minibatch=16(但是对比实验的OHEM的batch size最小也是128,不知道这个会不会有什么影响)

实验结果

  对不同的超参数做了对比,也跟OHEM以及state-of-art在COCO上做了对比,结果如下:

    

目标检测之focal loss_第3张图片

     Table1对比了不同的超参数a和r对结果的影响,选了个最好的组合,并且在这个组合下对比了和OHEM的效果,从结果看提升还是很明显的。

目标检测之focal loss_第4张图片

     table2对比了和目前主流方法的对比,值得注意的是table2的结果比table1还要高一些,作者解释说是多训了50%的轮数。。。

整体总结

  工作的idea是比较简单的,paper发的看上去也比较着急,实验不太solid,里面文字错误也有一些,应该是个发出来占坑的paper,一些比较本质的问题其实没有回答的很好。focal loss的本质还是对不同的样本做不同的weight,与OHEM是不同的weight分配方式。作者论文里解释说OHEM单纯的去掉大部分的简单负样本,不如focal loss把所有样本都考虑进来要好,这个逻辑不太成立,focal loss其实也是希望把简单负样本的weight降低,为什么降成0就不够好,这个事情没能通过实验加以证明(比如在focal loss的基础上把排名靠后的样本weight置成0)。和OHEM相比其合理性可能在于一个soft的方式来分配weight通常是要更合理更可tuning的方式,同样的思想应该也可以用到metric-learning之类的需要hard negtive sampling的方法中。

你可能感兴趣的:(目标检测之focal loss)