【trick 4】Focus Loss —— 解决one-stage目标检测中正负样本不均衡的问题

来自:

https://arxiv.org/abs/1708.02002

目录

  • 一、提出背景
  • 二、设计思路
  • 三、总结优缺点
  • 四、PyTorch实现
  • Reference

一、提出背景

目前目标检测的框架一般分为两种:基于候选区域的two-stage的检测框架(比如fast r-cnn系列),基于回归的one-stage的检测框架(yolo,ssd这种),two-stage的速度较慢但效果好,one-stage的速度快但效果差一些。

对于one-stage的检测器准确率不高的问题,论文作者给出了解释:由于正负样本不均衡的问题(感觉理解成简单-难分样本不均衡比较好)。 什么意思呢,就是说one-stage中能够匹配到目标的候选框(正样本)个数一般只用十几个或几十个,而没匹配到的候选框(负样本)大概有 1 0 4 − 1 0 5 10^4 - 10^5 104105个。而负样本大多数都是简单易分的,对训练起不到什么作用,但是数量太多会淹没掉少数但是对训练有帮助的困难样本。

其实two-stage的目标检测也存在着正负样本不均衡的问题,我们都知道,two-stage目标检测是将整个检测过程分成两部分,第一部分先选出一些候选框(如faster rcnn 的rpn),大概2000个左右,在第二阶段再进行筛选,虽然这时正负样本也是存在不均衡的,但是(十几个或几十个:2000) 相对(十几个或几十个: 1 0 4 − 1 0 5 10^4 - 10^5 104105),这时候好了太多了,所以我们的Focal Loss主要针对的是one-stage的目标检测算法。

那么正负样本不均衡,会带来什么问题呢?

  1. 训练效率低下。 training is inefficient as most locations are easy negatives that contribute no useful learning signal;
  2. 模型精度变低。 过多的负样本会主导训练,使模型退化。en masse,the easy negatives can overwhelm training and lead to degenerate models.

针对上述问题,一般的解决方法是难例挖掘(hard negative mining),不过该论文提出了一种新型的Loss函数,试着解决这个问题。

二、设计思路

Focus Loss设计的一个主要的思路就是:希望那些hard examples对损失的贡献变大,使网络更倾向于从这些样本上学习。防止由于easy examples过多,主导整个损失函数。

作者先以二分类为例进行说明:
先看看我们最常用的交叉熵损失函数:

C E ( y , p ) = { − l o g ( p ) y = 1 − l o g ( 1 − p ) 其 他 CE(y,p)=\begin{cases} -log(p) & y=1 \\ -log(1-p) & 其他 \end{cases} CE(y,p)={ log(p)log(1p)y=1

其中 y为真实标签,p为预测概率。
为了简便也可以写为:

p t = { p y = 1 1 − p 其 他 p_t=\begin{cases} p & y=1 \\ 1-p & 其他 \end{cases} pt={ p1py=1 and rewirte C E ( y , p ) = C E ( p t ) = − l o g ( p t ) CE(y,p)=CE(p_t)=-log(p_t) CE(y,p)=CE(pt)=log(pt)

要对类别不均衡问题对loss的贡献进行一个控制,即加上一个控制权重即可,最初作者的想法即如下这样,对于属于少数类别的样本,增大α即可:

C E ( p t ) = − α t l o g ( p t ) CE(p_t)=- \alpha_t log(p_t) CE(pt)=αtlog(pt)
α t = { α y = 1 1 − α 其 他 ; 其 中 α ∈ [ 0 , 1 ] \alpha_t=\begin{cases} \alpha & y=1 \\ 1-\alpha & 其他 \end{cases}; 其中\alpha\in[0, 1] αt={ α1αy=1;α[0,1]

注意:这里的 α t \alpha_t αt并不是正负样本的比例,而是一个超参数,用来平衡正负样本的权重。

但是上式只是解决了正负样本之间的平衡问题,并没有区分易分/难分样本,因此就有了下面的公式:

F L ( p t ) = − ( 1 − p t ) γ l o g ( p t ) FL(p_t)=- (1-p_t)^\gamma log(p_t) FL(pt)=(1pt)γlog(pt)

分析:

  1. 简单样本: 容易预测正确。当y=1(正), p p p->1, p t p_t pt ->1, ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ->0, loss小;当y为其他时(负), p p p->0, p t p_t pt->1, ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ->0, loss小;所以综合来看,当样本为简单样本的时候,损失会比原来的损失小很多倍。
  2. 复杂样本: 容易预测错误。当y=1(正), p p p->0, p t p_t pt ->0, ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ->1,loss下降一点点(几乎不变);当y为其他时(负), p p p->1, p t p_t pt->0, ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ->1,loss下降一点点(几乎不变)。所以综合来看,当样本为复杂样本时,损失和原来的损失差不多,不会小太多。
    举例:前面4行是简单样本(数量很多),使用FL损失函数使其损失值下降了很多倍(相比CE损失函数);而后面两个复杂样本(数量较少),使用FL损失函数后损失值只下降了很少倍。
    【trick 4】Focus Loss —— 解决one-stage目标检测中正负样本不均衡的问题_第1张图片

所以 γ \gamma γ参数是用来区分易分/难分样本的。它可以通过降低简单样本(数量多)的损失权重,使损失函数更加专注于困难样本(数量),防止简单样本主导整个损失函数。

综合两个方面,最终的损失函数为:

F L ( p t ) = − α t ( 1 − p t ) γ l o g ( p t ) FL(p_t)=- \alpha_t (1-p_t)^\gamma log(p_t) FL(pt)=αt(1pt)γlog(pt)
p t = { p y = 1 1 − p 其 他 p_t=\begin{cases} p & y=1 \\ 1-p & 其他 \end{cases} pt={ p1py=1
α t = { α y = 1 ( 正 样 本 ) 1 − α 其 他 ( 负 样 本 ) ; 其 中 α ∈ [ 0 , 1 ] \alpha_t=\begin{cases} \alpha & y=1(正样本) \\ 1-\alpha & 其他(负样本) \end{cases}; 其中\alpha\in[0, 1] αt={ α1αy=1()();α[0,1]

其中 α t \alpha_t αt来协调正负样本之间的平衡, γ \gamma γ来降低简单样本的权重,使损失函数更关注困难样本。

举例说明:
【trick 4】Focus Loss —— 解决one-stage目标检测中正负样本不均衡的问题_第2张图片
如上图,横坐标代表 p t p_t pt,纵坐标表示各种样本所占的loss权重。对于正样本,我们希望 p p p越接近1越好,也就是 p t p_t pt越接近1越好;对于负样本,我们希望 p p p越接近0越好,也就是 p t p_t pt越接近1越好。所以不管是正样本还是负样本,我们总是希望他预测得到的 p t p_t pt越大越好。如上图所示, p t ∈ [ 0.6 , 1 ] p_t\in[0.6, 1] pt[0.6,1]就是我们预测效果比较好的样本(也就是易分样本)了。
显然可以想象这部分的样本数量很多,所以占比是比较高的(如图中蓝色线区域),我们用 ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ来降低易分样本的损失占比 / 损失贡献(如图其他颜色的曲线)。

三、总结优缺点

优点:

  1. 解决了one-stage object detection中图片中正负样本(前景和背景)不均衡的问题;
  2. 降低简单样本的权重,使损失函数更关注困难样本;

缺点:

  1. 模型很容易收到噪声干扰:会将噪声当成复杂样本,使模型过拟合退化;
  2. 模型的初期,数量多的一类可能主导整个loss,所以训练初期可能训练不稳定;
  3. 两个参数 α t \alpha_t αt γ \gamma γ具体的值很难定义,需要自己调参,调的不好可能效果会更差(论文中的 α t \alpha_t αt=0.25, γ \gamma γ=2最好)。

    【trick 4】Focus Loss —— 解决one-stage目标检测中正负样本不均衡的问题_第3张图片

四、PyTorch实现

在yolo_v3_spp的实现代码。

class FocalLoss(nn.Module):
    # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
    def __init__(self, loss_fcn, gamma=2, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma   # 参数gamma
        self.alpha = alpha   # 参数alpha
        # reduction: 控制损失输出模式 sum/mean/none 这里定义的交叉熵损失BCE都是mean
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'  # 不知道这句有什么用?  required to apply FL to each element

    def forward(self, pred, true):
        loss = self.loss_fcn(pred, true)  # 普通BCE Loss
        # p_t = torch.exp(-loss)
        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability

        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        pred_prob = torch.sigmoid(pred)  # prob from logits 如果模型最后没有 nn.Sigmoid(),那么这里就需要对预测结果计算一次 Sigmoid 操作
        # ture=0,p_t=1-p; true=1, p_t=p
        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
        # ture=0, alpha_factor=1-alpha; true=1,alpha_factor=alpha
        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
        modulating_factor = (1.0 - p_t) ** self.gamma
        # loss = focus loss(代入公式即可)
        loss *= alpha_factor * modulating_factor

        if self.reduction == 'mean': # 一般是mean
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss

Reference

  1. https://blog.csdn.net/Code_Mart/article/details/89736187
  2. https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py

你可能感兴趣的:(深度学习)