[日常] 关于 Focal Loss(附实现代码)

最近一直在做人脸表情相关的方向,这个领域的 DataSet 数量不大,而且往往存在正负样本不均衡的问题。
一般来说,解决正负样本数量不均衡问题有两个途径:

  1. 设计采样策略,一般都是对数量少的样本进行重采样
  2. 设计 Loss,一般都是对不同类别样本进行权重赋值

我两种策略都使用过,本文讲的是第二种策略中的 Focal Loss。

目录

  1. 理论分析
  2. 源码讲解
  3. 实战使用

1. 理论分析

Focal Loss 是 Kaiming He 和 RBG 在 2017 年的 “Focal Loss for Dense Object Detection” 论文中所提出的一种新的 Loss Function,Focal Loss 主要是为了解决样本类别不均衡问题(也有人说实际上也是解决了 gradient 被 easy example dominant 的问题)。

[日常] 关于 Focal Loss(附实现代码)_第1张图片

网上已经有很多很好的见解了,我就不瞎说了,大家可以看看下面的一些文章:

  • 如何评价Kaiming的Focal Loss for Dense Object Detection?
  • Focal Loss
  • 何恺明大神的「Focal Loss」,如何更好地理解?
  • Focal Loss理解
  • Focal Loss论文阅读 - Focal Loss for Dense Object Detection
  • Focal loss论文详解

2.源码讲解

import torch
import torch.nn as nn
import torch.nn.functional as F

# 针对二分类任务的 Focal Loss
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        self.alpha = torch.tensor(alpha).cuda()
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, pred, target):
        # 如果模型最后没有 nn.Sigmoid(),那么这里就需要对预测结果计算一次 Sigmoid 操作
        # pred = nn.Sigmoid()(pred)

        # 展开 pred 和 target,此时 pred.size = target.size = (BatchSize,1) 
        pred = pred.view(-1,1)
        target = target.view(-1,1)

		# 此处将预测样本为正负的概率都计算出来,此时 pred.size = (BatchSize,2)
        pred = torch.cat((1-pred,pred),dim=1)

		# 根据 target 生成 mask,即根据 ground truth 选择所需概率
		# 用大白话讲就是:
		# 当标签为 1 时,我们就将模型预测该样本为正类的概率代入公式中进行计算
		# 当标签为 0 时,我们就将模型预测该样本为负类的概率代入公式中进行计算
        class_mask = torch.zeros(pred.shape[0],pred.shape[1]).cuda()
        # 这里的 scatter_ 操作不常用,其函数原型为:
        # scatter_(dim,index,src)->Tensor
        # Writes all values from the tensor src into self at the indices specified in the index tensor. 
        # For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.
        class_mask.scatter_(1, target.view(-1, 1).long(), 1.)

		# 利用 mask 将所需概率值挑选出来
        probs = (pred * class_mask).sum(dim=1).view(-1,1)
        probs = probs.clamp(min=0.0001,max=1.0)

        # 计算概率的 log 值
        log_p = probs.log()

		# 根据论文中所述,对 alpha 进行设置(该参数用于调整正负样本数量不均衡带来的问题)
		alpha = torch.ones(pred.shape[0],pred.shape[1]).cuda()
		alpha[:,0] = alpha[:,0] * (1-self.alpha)
		alpha[:,1] = alpha[:,1] * self.alpha
        alpha = (alpha * class_mask).sum(dim=1).view(-1,1)
        
        # 根据 Focal Loss 的公式计算 Loss
        batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p
 
 		# Loss Function的常规操作,mean 与 sum 的区别不大,相当于学习率设置不一样而已
        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()

        return loss

# 针对 Multi-Label 任务的 Focal Loss
class FocalLoss_MultiLabel(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, size_average=True):
        super(FocalLoss_MultiLabel, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, pred, target):
        criterion = FocalLoss(self.alpha,self.gamma,self.size_average)
        loss = torch.zeros(1,target.shape[1]).cuda()

		# 对每个 Label 计算一次 Focal Loss
        for label in range(target.shape[1]):
            batch_loss = criterion(pred[:,label],target[:,label])
            loss[0,label] = batch_loss.mean()

        # Loss Function的常规操作
        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
             
        return loss

更新:
编写针对多分类任务的 CELoss 和 Focal Loss,可通过 use_alpha 参数决定是否使用 α \alpha α 参数,并解决之前版本中所出现的 Loss变为 nan 的 bug(原因出自 log 操作,当对过小的数值进行 log 操作,返回值将变为 nan)。

# 针对多分类任务的 CELoss 和 Focal Loss
import torch
import torch.nn as nn
import torch.nn.functional as F

class CELoss(nn.Module):
    def __init__(self, class_num, alpha=None, use_alpha=False, size_average=True):
        super(CELoss, self).__init__()
        self.class_num = class_num
        self.alpha = alpha
        if use_alpha:
            self.alpha = torch.tensor(alpha).cuda()

        self.softmax = nn.Softmax(dim=1)
        self.use_alpha = use_alpha
        self.size_average = size_average

    def forward(self, pred, target):
        prob = self.softmax(pred.view(-1,self.class_num))
        prob = prob.clamp(min=0.0001,max=1.0)
        
        target_ = torch.zeros(target.size(0),self.class_num).cuda()
        target_.scatter_(1, target.view(-1, 1).long(), 1.)
        
        if self.use_alpha:
            batch_loss = - self.alpha.double() * prob.log().double() * target_.double()
        else:
            batch_loss = - prob.log().double() * target_.double()
        
        batch_loss = batch_loss.sum(dim=1)

        # print(prob[0],target[0],target_[0],batch_loss[0])
        # print('--')

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()

        return loss

class FocalLoss(nn.Module):
    def __init__(self, class_num, alpha=None, gamma=2, use_alpha=False, size_average=True):
        super(FocalLoss, self).__init__()
        self.class_num = class_num
        self.alpha = alpha
        self.gamma = gamma
        if use_alpha:
            self.alpha = torch.tensor(alpha).cuda()

        self.softmax = nn.Softmax(dim=1)
        self.use_alpha = use_alpha
        self.size_average = size_average

    def forward(self, pred, target):
        prob = self.softmax(pred.view(-1,self.class_num))
        prob = prob.clamp(min=0.0001,max=1.0)
        
        target_ = torch.zeros(target.size(0),self.class_num).cuda()
        target_.scatter_(1, target.view(-1, 1).long(), 1.)
        
        if self.use_alpha:
            batch_loss = - self.alpha.double() * torch.pow(1-prob,self.gamma).double() * prob.log().double() * target_.double()
        else:
            batch_loss = - torch.pow(1-prob,self.gamma).double() * prob.log().double() * target_.double()

        batch_loss = batch_loss.sum(dim=1)

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()

        return loss

注意:一定要对所求概率进行 clamp 操作,不然当某一概率过小时,进行 log 操作,会使得 loss 变为 nan!!!

3. 实战使用

最近在 RAF DataSet (Basic 部分)上尝试使用了这些 Loss Function,其中,使用的模型为 ResNet-18,输入图像尺度为 112 ∗ 112 112 * 112 112112

Acc1_avg ACC2_avg
CrossEntropy Loss (官方) 82.92 75.98
CELoss (no Alpha) 83.60 76.20
CELoss (Alpha) 83.84 76.21
Focal Loss (no Alpha) 82.54 74.17
Focal Loss (Alpha) 83.05 75.87

其中,alpha 为 numClass 维向量,计算公式为 α i = 1 − ( n u m O f C l a s s i / n u m O f A l l ) \alpha_i=1-(numOfClass_i/numOfAll) αi=1(numOfClassi/numOfAll);Acc1_avg 为所有类别 Acc1 指标的平均值,计算公式为 A c c 1 i = n p . s u m ( p r e d = = t a r g e t ) / p r e d . s h a p e [ 0 ] Acc1_i=np.sum(pred==target)/pred.shape[0] Acc1i=np.sum(pred==target)/pred.shape[0];Acc2_avg 为所有类别 Acc2 指标的平均值,计算公式为 A c c 2 i = n p . s u m ( ( p r e d = = i ) ∗ ( t a r g e t = = i ) ) / n p . s u m ( t a r g e t = = i ) Acc2_i=np.sum((pred==i)*(target==i))/np.sum(target==i) Acc2i=np.sum((pred==i)(target==i))/np.sum(target==i)

参考资料:

  • TORCH.TENSOR
  • Pytorch scatter_ 理解轴的含义

如果你看到了这篇文章的最后,并且觉得有帮助的话,麻烦你花几秒钟时间点个赞,或者受累在评论中指出我的错误。谢谢!

作者信息:
知乎:没头脑
LeetCode:Tao Pu
CSDN:Code_Mart
Github:Bojack-want-drink

你可能感兴趣的:(Paper)