(大佬)睿智的目标检测9——Focal loss详解及其实现

原文链接:https://blog.csdn.net/weixin_44791964/article/details/102853782

睿智的目标检测9——Focal loss详解及其实现

  • 2020/5/13日更新
  • 学习前言
  • 什么是Focal loss
    • 控制正负样本的权重
    • 控制容易分类和难分类样本的权重
    • 两种权重控制方法合并
  • 实现方式

2020/5/13日更新

之前的focal_loss没有考虑存在目标的框中的其它种类标签为0,已经修改代码!请各位注意啦!

学习前言

最近觉得yolo3的训练效果还可以优化,除去改变特征提取网络外,还可以改变其loss的组成。
在这里插入图片描述

什么是Focal loss

Focal loss是何恺明大神提出的一种新的loss计算方案。
其具有两个重要的特点。

1、控制正负样本的权重

2、控制容易分类和难分类样本的权重

正负样本的概念如下:
一张图像可能生成成千上万的候选框,但是其中只有很少一部分是包含目标的的,有目标的就是正样本,没有目标的就是负样本。

容易分类和难分类样本的概念如下:
假设存在一个二分类,样本1属于类别1的pt=0.9,样本2属于类别1的pt=0.6,显然前者更可能是类别1,其就是容易分类的样本;后者有可能是类别1,所以其为难分类样本。

如何实现权重控制呢,请往下看:

控制正负样本的权重

如下是常用的交叉熵loss,以二分类为例:
在这里插入图片描述
我们可以利用如下Pt简化交叉熵loss。
在这里插入图片描述
此时:
在这里插入图片描述
想要降低负样本的影响,可以在常规的损失函数前增加一个系数αt。与Pt类似,当label=1的时候,αt=α;当label=otherwise的时候,αt=1 - α,a的范围也是0到1。此时我们便可以通过设置α实现控制正负样本对loss的贡献在这里插入图片描述
其中:
在这里插入图片描述
分解开就是:
在这里插入图片描述

控制容易分类和难分类样本的权重

按照刚才的思路,一个二分类,样本1属于类别1的pt=0.9,样本2属于类别1的pt=0.6,也就是 是某个类的概率越大,其越容易分类 所以利用1-Pt就可以计算出其属于容易分类或者难分类。
具体实现方式如下。
在这里插入图片描述
其中:
( 1 − p t ) γ ( 1 − p t ) γ ( 1 − p t ) γ (1−pt)γ(1−pt)γ (1-p_{t})^{γ} (1pt)γ(1pt)γ(1pt)γ(1pt)γ
称为调制系数(modulating factor)

1、当pt趋于0的时候,调制系数趋于1,对于总的loss的贡献很大。当pt趋于1的时候,调制系数趋于0,也就是对于总的loss的贡献很小。
2、当γ=0的时候,focal loss就是传统的交叉熵损失,可以通过调整γ实现调制系数的改变。

两种权重控制方法合并

通过如下公式就可以实现控制正负样本的权重控制容易分类和难分类样本的权重
在这里插入图片描述

实现方式

def focal(alpha=0.25, gamma=2.0):
    def _focal(y_true, y_pred):
        # y_true [batch_size, num_anchor, num_classes+1]
        # y_pred [batch_size, num_anchor, num_classes]
        labels         = y_true[:, :, :-1]
        anchor_state   = y_true[:, :, -1]  # -1 是需要忽略的, 0 是背景, 1 是存在目标
        classification = y_pred
    # 找出存在目标的先验框
    indices_for_object        = backend.where(keras.backend.equal(anchor_state, 1))
    labels_for_object         = backend.gather_nd(labels, indices_for_object)
    classification_for_object = backend.gather_nd(classification, indices_for_object)

    # 计算每一个先验框应该有的权重
    alpha_factor_for_object = keras.backend.ones_like(labels_for_object) * alpha
    alpha_factor_for_object = backend.where(keras.backend.equal(labels_for_object, 1), alpha_factor_for_object, 1 - alpha_factor_for_object)
    focal_weight_for_object = backend.where(keras.backend.equal(labels_for_object, 1), 1 - classification_for_object, classification_for_object)
    focal_weight_for_object = alpha_factor_for_object * focal_weight_for_object ** gamma

    # 将权重乘上所求得的交叉熵
    cls_loss_for_object = focal_weight_for_object * keras.backend.binary_crossentropy(labels_for_object, classification_for_object)

    # 找出实际上为背景的先验框
    indices_for_back        = backend.where(keras.backend.equal(anchor_state, 0))
    labels_for_back         = backend.gather_nd(labels, indices_for_back)
    classification_for_back = backend.gather_nd(classification, indices_for_back)

    # 计算每一个先验框应该有的权重
    alpha_factor_for_back = keras.backend.ones_like(labels_for_back) * (1 - alpha)
    focal_weight_for_back = classification_for_back
    focal_weight_for_back = alpha_factor_for_back * focal_weight_for_back ** gamma

    # 将权重乘上所求得的交叉熵
    cls_loss_for_back = focal_weight_for_back * keras.backend.binary_crossentropy(labels_for_back, classification_for_back)

    # 标准化,实际上是正样本的数量
    normalizer = tf.where(keras.backend.equal(anchor_state, 1))
    normalizer = keras.backend.cast(keras.backend.shape(normalizer)[0], keras.backend.floatx())
    normalizer = keras.backend.maximum(keras.backend.cast_to_floatx(1.0), normalizer)

    # 将所获得的loss除上正样本的数量
    cls_loss_for_object = keras.backend.sum(cls_loss_for_object)
    cls_loss_for_back = keras.backend.sum(cls_loss_for_back)

    # 总的loss
    loss = (cls_loss_for_object + cls_loss_for_back)/normalizer

    return loss
return _focal
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
                                

你可能感兴趣的:(人脸检测)