pytorch学习笔记 | Focal loss的原理与pytorch实现

Focal 原理简述

Focal loss是一个针对单阶段物体检测任务中正负样本不均衡而提出来的损失函数,论文地址来自arxiv

数学定义

先放focal loss(FL)和cross entropy(CE)两个函数的数学定义。
其中 p 为概率,而 y 为 0 或 1 的标签。
pytorch学习笔记 | Focal loss的原理与pytorch实现_第1张图片
可以看到focal loss的设计很简单明了,就是在标准交叉熵损失函数的引入一个因子 ( 1 − p t ) λ (1 - p_t)^\lambda (1pt)λ λ = 0 \lambda= 0 λ=0时,损失函数就是标准交叉熵 。

损失函数意义

focal loss 称为焦点损失函数,通过改进标准的二元交叉熵损失函数来控制对正负样本的训练,为了解决在one-stage目标检测中正负样本严重不均衡的一种策略。该损失函数的设计思想类似于boosting,降低容易分类的样本对损失函数的影响,注重较难分类的样本的训练。

在常规的交叉熵函数的基础上,添加一个系数项,其影响从下图曲线来看可知:

  • 当样本的预测分数较高( p t p_t pt较大,指的是模型判断正确的概率较大)时,其计算所得的loss将变小,这一部分样本视为分类较好的数据,我们降低其在总体损失值中的比重;
  • 较难训练的样本则计算得到更大的loss值,模型将着重针对这些样本进行训练和梯度更新。
    pytorch学习笔记 | Focal loss的原理与pytorch实现_第2张图片
    进一步探讨,当我们考虑类别的比重不相同时,我们可以给各个类别添加一个权重常数 α \alpha α,比如想使正样本初始权重为0.8,负样本就为0.2,那么可以令 α = 0.8 \alpha = 0.8 α=0.8,然后该权重常数乘以对应类别的交叉熵计算中得以生效。这样就能够平衡正负样本的重要性。但是要解决简单分类和困难分类样本的问题则需要依赖 λ, λ越大,损失值计算结果越小,这能够实现对容易样本降低权重的平滑调节。对于物体检测,实验发现 λ=2时最优。
    个人认为该损失函数的设计思想可以应用于其他同样有样本不均衡特点的分类任务。

Pytorch 实现

实现思想很简单,就是先利用input和target计算出因子项 ( 1 − p t ) λ (1 - p_t)^\lambda (1pt)λ,然后乘以标准交叉熵即可。

import torch
import torch.nn as nn
import torch.nn.functional as F 
import config as cfg 
from torch.autograd import Variable

class FocalLoss(nn.Module):
    def __init__(self,):
        super(FocalLoss, self).__init__()
        self.device = torch.device("cuda:" + str(cfg.DEVICE_ID) if torch.cuda.is_available() else "cpu")

    def forward(self, inputs, targets,gamma=2, focal_loss_alpha=0.8):        
    	# 计算正负样本权重
        alpha_factor = torch.ones(targets.shape) * focal_loss_alpha
        alpha_factor = torch.where(torch.eq(targets, 1), alpha_factor, 1. - alpha_factor)
        # 计算因子项
        focal_weight = torch.where(torch.eq(targets, 1), 1. - inputs, inputs)
        # 得到最终的权重
        focal_weight = alpha_factor * torch.pow(focal_weight, focal_loss_gamma)
        targets = targets.type(torch.FloatTensor) 
        # 计算标准交叉熵
        bce = -(targets * torch.log(inputs) + (1. - targets) * torch.log(1. - inputs))
        # focal loss 
        cls_loss = focal_weight * bce
        return cls_loss.sum()

如果你要在GPU上跑,那么你可以尝试以下代码。

import torch
import torch.nn as nn
import torch.nn.functional as F 
import config as cfg 
from torch.autograd import Variable

class FocalLoss(nn.Module):
    def __init__(self,):
        super(FocalLoss, self).__init__()
        self.device = torch.device("cuda:" + str(cfg.DEVICE_ID) if torch.cuda.is_available() else "cpu")

    def forward(self, inputs, targets):        
        gpu_targets = targets.cuda()
        alpha_factor = torch.ones(gpu_targets.shape).cuda() * cfg.focal_loss_alpha
        alpha_factor = torch.where(torch.eq(gpu_targets, 1), alpha_factor, 1. - alpha_factor)
        focal_weight = torch.where(torch.eq(gpu_targets, 1), 1. - inputs, inputs)
        focal_weight = alpha_factor * torch.pow(focal_weight, cfg.focal_loss_gamma)
        targets = targets.type(torch.FloatTensor)
        inputs = inputs.cuda()
        targets = targets.cuda()
        bce = F.binary_cross_entropy(inputs, targets)
        focal_weight = focal_weight.cuda()
        cls_loss = focal_weight * bce
        return cls_loss.sum()

你可能感兴趣的:(AI,CV,DL,pytorch)