Focal loss的Pytorch实现

1.Focal Loss介绍

Focal loss是在CrossEntropy基础上进行改进的,主要解决了训练中正负样本和简单困难样本重要性不均衡的问题。首次提出是在论文《Focal Loss for Dense Object Detection》中,作者Kaiming He的出发点是想解决样本的类别不均衡导致的one-stage和two-stage的表现差异问题。

样本的不平衡将导致两个问题:1.训练难度上升,因为大部分的样本都是简单样本,很难从中学习到有用的信息;2.大量的某一类样本会使模型的学习能力下降。Focal loss通过在内部加权来解决类别不平衡问题:简单样本降低权重,正负样本按比例分配权重。

 

2.损失函数公式

Focal loss是在交叉熵损失函数基础上进行的修改,首先回顾二分类交叉上损失:

              

是经过激活函数的输出,所以在0-1之间。可见普通的交叉熵对于正样本而言,输出概率越大损失越小。对于负样本而言,输出概率越小则损失越小。

为了提高对更困难样本的关注,作者加入因子gamma。gamma=0时函数等价于交叉熵loss,gamma>0时对于易分类样本的损失将会更小,困难样本的损失会变大。例如gamma为2,对于正类样本而言,预测结果为0.95肯定是简单样本,所以(1-0.95)的gamma次方就会很小,这时损失函数值就变得更小。而预测概率为0.3的样本其损失相对很大。对于负类样本而言同样,预测0.1的结果应当远比预测0.7的样本损失值要小得多。对于预测概率为0.5时,损失只减少了0.25倍,所以更加关注于这种难以区分的样本。这样减少了简单样本的影响,大量预测概率很小的样本叠加起来后的效应才可能比较有效。

              

同时,为了平衡正负样本比例,加入因子alpha。alpha代表正负样本的比例,若alpha=0.3,代表此时正样本要比负样本占比小,负样本例更易分因此得到更大的权重(1-0.3)。

                           

 

3.Pytorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
 
class FocalLoss(nn.Module):
    r"""
        This criterion is a implemenation of Focal Loss, which is proposed in 
        Focal Loss for Dense Object Detection.
            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
        The losses are averaged across observations for each minibatch.
        Args:
            alpha(1D Tensor, Variable) : the scalar factor for this criterion
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 
                                   putting more focus on hard, misclassified examples
            size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                However, if the field size_average is set to False, the losses are
                                instead summed for each minibatch.
    """
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                self.alpha = Variable(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average
 
    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs)
 
        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)
        #print(class_mask)
 
 
        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]
 
        probs = (P*class_mask).sum(1).view(-1,1)
 
        log_p = probs.log()
        #print('probs size= {}'.format(probs.size()))
        #print(probs)
 
        batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p 
        #print('-----bacth_loss------')
        #print(batch_loss)
 
 
        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss

参考资料:https://zhuanlan.zhihu.com/p/28527749

                 https://www.cnblogs.com/king-lps/p/9497836.html

你可能感兴趣的:(Focal loss的Pytorch实现)