<Focal Loss for Dense Object Detection>论文解读

目录

  • 1.简介
  • 2.模型
    • 2.1 二阶段要比单阶段模型效果好本质原因
    • 2.2 模型结构
    • 2.3.focal loss
      • 2.3.1 focal loss公式说明
        • (1) bec loss
        • (2) 控制容易分类/难分类样本的权重
        • (3)控制正负样本的权重
        • (4) focal loss
        • (5) bce vs ce ,即二分类交叉熵 vs  多分类交叉熵
      • 2.3.3 论文其他设定
    • 2.4 消融实验
  • 3.源码详解
  • 4 ref

1.简介

目标识别有两大经典结构: 第一类是以Faster RCNN为代表的二阶段识别方法,这种结构的第一阶段专注于proposal的提取,第二阶段则对提取出的proposal进行分类和精确坐标回归。
二阶段结构准确度较高,但因为第二阶段需要单独对每个proposal进行分类/回归,速度就打了折扣;目标识别的第二类结构是以YOLO和SSD为代表的单阶段结构,它们摒弃了提取proposal的过程,只用一级就完成了识别/回归,虽然速度较快但准确率远远比不上两级结构。那有没有办法在单阶段结构中也能实现较高的准确度呢?Focal Loss就是要解决这个问题。
<Focal Loss for Dense Object Detection>论文解读_第1张图片
这是在coco数据集上的mAP指标, 可以看出要比一些单阶段的例如ssd,还有二阶段fpn faster rcnn都要高。在当时2018年的时候,还是SOTA的。

2.模型

2.1 二阶段要比单阶段模型效果好本质原因

作者认为,单阶段效果比二阶段差的根本原因是类别不均衡
二阶段模型一般在训练过程,第一个阶段筛选出的proposals,这已经过滤掉了大部分的背景bbox,第二个阶段采样过程保持正负样本的一定比例,例如fixed foreground-to-background ratio (1:3), or online hard example mining (OHEM). 这样就保持了前后背景样本的比例平衡问题。
而单阶段的模型,没有proposal,针对所有的候选位置进行采样,这些bbox大约有∼100k 左右。负样本的数量远远大于正样本的数量,造成正负样本的极不均衡。采样过程可以学习二阶段模型,但是这个过程肯定是低效的,因为训练过程还是大部分被`easily classified background主导,所以整体的效果稍差。

而正负样本的极不平衡会造成如下影响:

在计算loss时,负样本数量很多,所以在loss中负样本的比重就很大,然而负样本比较容易分类(easy negatives),所以给loss能提供的有用信息较少。
而正样本是我们最终要得到的检测结果,比较难分类(hard positive),所以提供的loss信息比较重要,但是由于数量少,这些关键的loss很容易被淹没掉。

2.2 模型结构

<Focal Loss for Dense Object Detection>论文解读_第2张图片
模型的结构中规中据
backbone: resnet 50 or 100
neck: fpn
head: dease head ( class + bbox regression)

最大的亮点是在于利用focal loss解决 关于前后背景/简单,难例不均衡问题,从而抑制easy sample,让更多的正负hard sample在loss上起到更大作用,更好的解决样本类别不均衡问题。

2.3.focal loss

很多博客都没有解释清楚,感觉很不明白,所以对照mmdetection里面相关源码进行详细推导和解释,希望能讲解清楚具体什么是focal loss,focal loss具体是怎么计算的,究竟能怎么应用的。

2.3.1 focal loss公式说明

Focal Loss是一种Loss计算方案。其具有两个重要的特点。

1、控制正负样本的权重
2、控制容易分类和难分类样本的权重

正负样本的概念如下:
目标检测本质上是进行密集采样,在一张图像生成成千上万的先验框(或者特征点),将真实框与部分先验框匹配,匹配上的先验框就是正样本,没有匹配上的就是负样本。
难易样本的概念如下:
假设存在一个二分类问题,样本1和样本2均为类别1。网络的预测结果中,样本1属于类别1的概率=0.9,样本2属于类别1的概率=0.6,前者预测的比较准确,是容易分类的样本;后者预测的不够准确,是难分类的样本。

<Focal Loss for Dense Object Detection>论文解读_第3张图片

从图中可以看出,一般样本可以分为四大类:

easy negative:全是背景,比较容易判断的负样本
easy positive:全是物体,比好容易判断的正样本
hard negative:包含部分物体,但大部分为背景,比较难判断的负样本
hard positive:包含部分背景,但大部分为物体,比较难判断的正样本

可以看出hard examples就是在背景和物体过渡的区域,但是由于每张图中的物体较少,也就是正样本比较少,所有这种hard examples也比较少,同时由于负样本有很多,所以easy negative就很多,因此easy examples也就远多于hard examples。
所以说正负样本不均衡可以引起hard-easy样本不均衡,进而使得loss被easy examples的loss所控制,从而使得模型没有一个有效的loss来指导训练,所以最终得到一个不好的模型,所以最后的准确率比较低。
所以我们需要Focal loss来赋予这些hard examples更多权重。

在这里插入图片描述
从上面可知,作者在paper中为了简化,拿二分类问题进行举例。

(1) bec loss

Focal loss是在交叉熵损失函数基础上进行的修改,首先回顾 二分类交叉熵(bce, binary_cross_entropy) 上损失:
<Focal Loss for Dense Object Detection>论文解读_第4张图片
其中:

	y  :表示真实标签值label,二分类的话:如果是前景y=1, 背景y=0
	log :是以e为底数的对数
	p :是预测predict score的 sigmoid 取值,代表对应为前景的概率值

论文里面的写法,是在pt在不同label下的概率,所以写成 CE(p; y) = CE(pt) = − log(pt),跟我上面详细写的内容一样其实。比较精简。
(论文bce)
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

可见普通的交叉熵对于正样本而言,输出概率越大损失越小。对于负样本而言,输出概率越小则损失越小。此时的损失函数在大量简单样本的迭代过程中比较缓慢且可能无法优化至最优。

那么Focal loss是怎么改进的呢?

(2) 控制容易分类/难分类样本的权重

在这里插入图片描述

首先在原有的基础上加了一个因子 gamma,其中 gamma>0 用来控制难易样本的贡献分配。使得减少易分类样本的影响,使得更关注于困难的、错分的样本。
例如 gamma=2,对于正类样本而言,预测结果为0.95肯定是简单样本,所以(1-0.95)的gamma次方就会很小,这时损失函数值就变得更小。
而正样本的预测概率如果是为0.3,其损失相对很大。
对于负类样本而言同样,预测0.1的结果应当远比预测0.7的样本损失值要小得多。
对于预测概率为0.5时,损失只减少了0.25倍,所以更加关注于这种难以区分的样本。
这样减少了简单样本的影响,大量预测概率很小的easy negative 样本叠加起来后的效应才可能减弱。hard 样本的作用才能叠加增强,从而更好起到作用。

只添加alpha虽然可以平衡正负样本的重要性,但是无法解决简单与困难样本的问题。

(3)控制正负样本的权重

此外,加入平衡因子alpha用来控制正负样本的贡献。可以平衡正负样本本身的比例不均:文中alpha取0.25,即正样本要比负样本占比小,这是因为负例易分。

在这里插入图片描述

gamma调节简单样本权重降低的速率,当gamma为0时即为交叉熵损失函数,当gamma增加时,调整因子的影响也在增加。实验发现gamma为2是最优。

<Focal Loss for Dense Object Detection>论文解读_第5张图片

(4) focal loss

综合对于正负样本/难易样本的调节因子,最终的focal loss如下所示:

<Focal Loss for Dense Object Detection>论文解读_第6张图片

注意:

		论文里面,对应的focal loss写法比较精简
		其实就是在y取不同值的时候(label = 0 or 1), pt, at也取到不同值对应的变体的总和。
		其中如果
				y = 1 ---> at = a, pt = 1-p
				y = 0 ---> at = 1-a, pt = p

(论文里面focal loss表示, 正如上所示)
在这里插入图片描述

调参重点:
gamma : 难例权重,越大越关注难例。gamma占主导地位。随着gamma的增大,alpha要相应的减小。
alpha:正负样本权重,越大越关注正样本。在gamma增加的时候,alpha要适当减小。

(5) bce vs ce ,即二分类交叉熵 vs  多分类交叉熵

论文里面一直使用都是二分类交叉熵(bce)来说明问题,但是实际的检测目标都是多分类交叉熵(ce),那么二者什么区别?具体计算的时候又是如何计算多分类focal loss的呢?

首先一句话概括两者区别,BCE用于“是不是”问题,例如LR输出概率,明天下雨or不下雨的概率;CE用于“是哪个”问题,比如多分类问题

BCE
BCE+sigmoid在很多地方都有用到,例如逻辑回归(LR)、点击率预测、多标签学习(Multi-label learning)等等。其通常是配合sigmoid函数使用,形式如下:
<Focal Loss for Dense Object Detection>论文解读_第7张图片
CE
CE+softmax是多分类任务里最常使用到的损失形式了,形式如下:

<Focal Loss for Dense Object Detection>论文解读_第8张图片

总结:

  • 仅看损失形式上,BCE好像既考虑了正样本损失又考虑了负样本的损失,而CE只考虑了正样本损失。 但其实,二者所使用的激活函数不同,前者使用sigmoid,后者使用softmax,softmax其形式上本身就考虑负类的信息在里面。
  • CE在二分类情况下本质上和BCE没有太大的区别,但可能优化上有细微不同。

参照刚才推导的公式,可以将二分类问题推广到多分类问题损失
<Focal Loss for Dense Object Detection>论文解读_第9张图片
focal loss 的多分类损失计算,采用的是bce loss的改进。
其中针对多类别标签y采用one-hot的格式(类似于[0,0,1], 其中第几位为1则是表示类别是几,例如前面这个表示为3), 其中输出为1则为正样本,输出为0则全部视为负样本,将所有类别的和相加得到单个输出的交叉熵。从而可以转化成二分类问题计算多分类问题的交叉熵。

例如 
3类别的检测,假设存在对于某个anchor预测,预测值的输出概率是p_sigmoid=[0.1, 0.2, 0.3] 。
gt是类别3,写成one-hot形式是 y=[0,0,1],focal loss设置值为a = 0.3, r =2。

y=[0,0,1]
p_sigmoid=[0.1, 0.2, 0.3]
a = 0.3, r =2

focal loss = 
-[0+(1-0.3)x0.1^2 x log(1-0.1) +
 0+ (1-0.3) x 0.2 ^2 x log(1-0.2) + 
 0.3 x 0.7 ^ 2 x log(0.3)+0] = -[-0.00073752 -0.00624802 -0.176984]=0.18396954528231524

代码实现参见后面源码分析

2.3.3 论文其他设定

(1) 并不是对于所有的anchor都计算loss,只是对于存在gt的所有anchor计算loss
<Focal Loss for Dense Object Detection>论文解读_第10张图片

(2)  初始化
<Focal Loss for Dense Object Detection>论文解读_第11张图片

在模型运行初始阶段,为了训练稳定性,设定了一个预先值π,即正样本的概率一般取到π=0.01
<Focal Loss for Dense Object Detection>论文解读_第12张图片
最后一层的卷积bias b稍有不同

2.4 消融实验

<Focal Loss for Dense Object Detection>论文解读_第13张图片

(a) 单独调alpha,在0.75最优
(b) alpha+gamma : alpha降低到最小,gamma较大最好。关注negtivate hard example最好。>
© 调整anchor scale or aspect。这个也不是anchor越大越多最好
(d) OHEM vs FL ,FL更好一些
(e) input size尺度, backbone大小影响

<Focal Loss for Dense Object Detection>论文解读_第14张图片
正样本和负样本的累积分布函数(CDF)如图4所示。如果我们观察正样本损失(左),我们会发现CDF看起来,随着gamma的增加,变化其实并不大,说明gamma对于正样本难例的提升作用较小。

gamma对负样本的影响截然不同。gamma=0时,正CDF和负CDF相当相像的然而,随着gamma的增加模型权重更多的关注在较难的负样本上。在里面事实上,当gamma=2(我们的默认设置)时loss损失很少来自于背景样本。

正如可能的那样可见,FL可以有效地降低easy negetive sample的影响,将所有注意力集中在hard negative examples.上。

3.源码详解

详细结构代码串讲内容参见:

下面针对Focal loss相关源码详解一下:
(1) 计算的时候,MMDetection 提供了 py 和 cuda 版本,py 版本如下所示:

"""
    Args:
        pred (torch.Tensor): The prediction with shape (N, C), C is the
            number of classes.所有预测输出的概率。
        target (torch.Tensor): The learning label of the prediction.真实的label,注意是one-hot格式的。
        weight (torch.Tensor, optional): Sample-wise loss weight.跟loss进行相乘的权重。
        gamma (float, optional): The gamma for calculating the modulating
            factor. Defaults to 2.0. focal loass的 参数,控制难易样本。
        alpha (float, optional): A balanced form for Focal Loss.
            Defaults to 0.25. focal loass的 参数,控制正负样本比例。
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
"""
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    if weight is not None:
        if weight.shape != loss.shape:
            if weight.size(0) == loss.size(0):
                # For most cases, weight is of shape (num_priors, ),
                #  which means it does not have the second axis num_class
                weight = weight.view(-1, 1)
            else:
                # Sometimes, weight per anchor per class is also needed. e.g.
                #  in FSAF. But it may be flattened of shape
                #  (num_priors x num_class, ), while loss is still of shape
                #  (num_priors, num_class).
                assert weight.numel() == loss.numel()
                weight = weight.view(loss.size(0), -1)
        assert weight.ndim == loss.ndim
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss

注意:

  1. 这是计算所有anchor输出的loss
  2. 论文中讲解的时候使用的是二分类的交叉上,这个计算都时候使用的是多分类交叉熵

(2) 通过计算实例进行相关比较

3类别的检测,假设存在对于某个anchor预测,预测值的输出概率是p_sigmoid=[0.1, 0.2, 0.3] 。
gt是类别3,写成one-hot形式是 y=[0,0,1],focal loss设置值为a = 0.3, r =2。

###### 已知条件
y=[0,0,1]
p_sigmoid=[0.1, 0.2, 0.3]
a = 0.3, r =2

###### 手动计算(具体公式参照前面内容)
focal loss = 
-[0+(1-0.3)x0.1^2 x log(1-0.1) +
 0+ (1-0.3) x 0.2 ^2 x log(1-0.2) + 
 0.3 x 0.7 ^ 2 x log(0.3)+0] = 
 -[-0.00073752 -0.00624802 -0.176984]=0.18396954528231524
###### 代码实现: 对比纯手工计算  vs focal loss源码tensor计算
def test_fl():
     import numpy as np
     ### compute by hand
     p_sigmoid = np.array([0.1,0.2,0.3])
     label = np.array([0.0, 0.0, 1.0])
     a = 0.3
     r = 2
     # print(np.log(1-p_sigmoid))
     cmp = a*label*np.power((1-p_sigmoid),r)*np.log(p_sigmoid)+\
            (1-a)*(1-label)*np.power(p_sigmoid,r)*np.log(1-p_sigmoid)
     print("-----> compute loss:",cmp, -np.sum(cmp))
     
     ### compute by focal loss
     p_sigmoid = torch.tensor([0.1,0.2,0.3])
     label = torch.tensor([0.0, 0.0, 1.0])
     a = torch.tensor(0.3)
     r = torch.tensor(2)
     pred_sigmoid = p_sigmoid
     target = label.type_as(p_sigmoid)
     pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
     focal_weight = (a * target + (1 - a) *
                     (1 - target)) * pt.pow(r)
     loss = F.binary_cross_entropy(
          pred_sigmoid, target, reduction='none')*focal_weight
     print("-----> compute loss:", loss, loss.sum())

结果是一致的
在这里插入图片描述

4 ref

BCE和CE的区别
轻松掌握 MMDetection 中常用算法(一):RetinaNet 及配置详解

你可能感兴趣的:(目标检测paper精读,目标检测,人工智能,深度学习)