损失函数focal loss深度理解与简单实现

本文主要从二值交叉熵损失函数出发,通过代码实现的方式,去更好地理解Focal Loss对于数据不平衡问题、难易样本问题损失是如何权衡的。

1.  首先我们给出比较官方一些的代码,具体就是mmdet中的py_sigmoid_focal_loss函数。

def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    """
        PyTorch version of `Focal Loss `_.

    """
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss

2.  根据理解,自己写得更简单直观的代码。

class BCEFocalLoss(nn.Module):
    def __init__(self,alpha=0.25,gamma=2):
        super(BCEFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self,logits,label):
        label = label.unsqueeze(1)  #  size(N, 1)
        assert label.size() == logits.size()
        probs = torch.sigmoid(logits) 
        pos_loss = -label*self.alpha*probs.log()*(1-probs)**self.gamma
        neg_loss = -(1-label)*(1-self.alpha)*(1-probs).log()*probs**self.gamma
        loss = (pos_loss + neg_loss).mean()
        return loss

3.  简单理论理解

比如,我们在做猫狗分类的任务,其中猫咪的图片有1000张,狗子的图片有300张,常见的二值交叉熵损失函数会倾向于学习到更多关于猫咪的知识,与此同时,会学到很少关于狗子的知识,这显然会让我们的分类器在识别狗子时容易失误,可以认为模型缺乏对狗子的理解。

因此,对于猫咪的图片其预测概率会更加置信,接近于1,此时focal loss的调制因子就起到了一种约束作用,其中(1-probs)^{\gamma} 会更加接近于0,而对于分类不准确狗子的样本,损失基本没有改变,整体而言,相当于增加了分类不准确样本在损失函数中的权重。

上述的描述是从样本量方面解释了focal loss对于难易样本的约束,宏观理解就是样本量大的通常更加容易学习,样本量少的损失通常更加容易被样本量大的损失盖住,降低其损失影响。

当然,不管是样本多的类,还是样本少的类,都是存在难易样本的,因此,focal loss对于这种情况也是发挥作用的。

4.  参数分析

(1)其中gamma作用用于调节难易样本对于总loss的权重,其值越大,调整因子的影响也越大,这里最佳取值在实验中设置为2。

(2)其中平衡因子alpha,主要用来平衡正负样本比例不均的,从理论来讲,对于正样本,比如狗子图片,其数量相对来说更少,我们应该采用一个大于0.5的alpha值,来平衡类别之间的权重,但实际实验中,论文采用了更加合适的取值0.25,这主要可能是因为gamma参数的影响占据了更大的作用,alpha在这里起到了一个额外的辅助微调整作用,避免了整体的矫枉过正或者力有不逮的情况。

你可能感兴趣的:(深度学习,图像识别,损失函数,python)