Focal Loss的简述与实现

文章目录

    • 交叉熵损失函数
    • 样本不均衡问题
    • Focal Loss
    • Focal Loss的代码实现

交叉熵损失函数

L o s s = L ( y , p ^ ) = − y l o g ( p ^ ) − ( 1 − y ) l o g ( 1 − p ^ ) Loss = L(y, \hat{p})=-ylog(\hat{p})-(1-y)log(1-\hat{p}) Loss=L(y,p^)=ylog(p^)(1y)log(1p^)

其中 p ^ \hat{p} p^为预测概率大小。此处的交叉熵公式只考虑了二分类的情况。

L c e ( y , p ^ ) = { − l o g ( p ^ ) , if  y =1 − l o g ( 1 − p ^ ) if  y = 0 L_{ce}(y, \hat{p}) = \begin{cases} -log(\hat{p}), & \text{if } y \text{=1} \\ -log(1-\hat{p}) & \text {if } y {=0} \end{cases} Lce(y,p^)={log(p^),log(1p^)if y=1if y=0

y y y为label,在二分类中对应0,1。

样本不均衡问题

对于所有样本,二分类问题的损失函数可以写为:

L = 1 N ( ∑ y i = 1 m − l o g ( p ^ ) + ∑ y i = 0 n − l o g ( 1 − p ^ ) ) L=\frac{1}{N}(\sum_{y_i =1}^m -log(\hat{p})+\sum_{y_i=0}^{n}-log(1-\hat{p})) L=N1(yi=1mlog(p^)+yi=0nlog(1p^))

其中m为正样本个数,n为负样本个数,N为样本总数,m+n=N。

当样本分布失衡时,在损失函数L的分布也会发生倾斜,如m<

Focal Loss

下式是focal loss的公式,其是用于对付样本不均衡的一个方法。

L f l = − ( 1 − p t ) γ l o g ( p t ) L_{fl}=-(1-p_t)^\gamma log(p_t) Lfl=(1pt)γlog(pt)

γ \gamma γ是可调节因子。

下式是转化过的交叉熵公式。

L c e = − l o g ( p t ) L_{ce} = -log({p_t}) Lce=log(pt)

上两式子中的 p t p_t pt如下合成。

p t = { p ^ if y=1 1 − p ^ if y=0 p_t=\begin{cases} \hat{p} & \text{if } \text{y=1} \\ 1-\hat{p} & \text{if } \text{y=0} \end{cases} pt={p^1p^if y=1if y=0

p t p_t pt越大,该样本的分类越准确。

对比交叉熵和focal loss的公式可以看出两者只有 ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ的区别。
( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ对于分类准确的样本影响较小,因为分类越准确, ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ越趋于0,而分类越不准确, ( 1 − p t ) γ (1-p_t)^\gamma (1pt)γ则趋于1。故相比交叉熵损失,focal loss对于分类不准确的样本,损失没有改变,对于分类准确的样本,损失会变小。 整体而言,相当于增加了分类不准确样本在损失函数中的权重。
相信很多人会在这里有一个疑问,样本难易分类角度怎么能够解决样本非平衡的问题,直觉上来讲样本非平衡造成的问题就是样本数少的类别分类难度较高。因此从样本难易分类角度出发,使得loss聚焦于难分样本,解决了样本少的类别分类准确率不高的问题,当然难分样本不限于样本少的类别,也就是focal loss不仅仅解决了样本非平衡的问题,同样有助于模型的整体性能提高。

Focal Loss的代码实现

此处代码来自于UFLDv2模型。

class SoftmaxFocalLoss(nn.Module):
    def __init__(self, gamma, ignore_lb=255, soft_loss = True, *args, **kwargs):
        super(SoftmaxFocalLoss, self).__init__()
        #伽马因子
        self.gamma = gamma
        #忽略标签值
        self.ignore_lb = ignore_lb
        #是否使用软损失
        self.soft_loss = soft_loss
        if not self.soft_loss:
            #负对数自然损失对象
            self.nll = nn.NLLLoss(ignore_index=ignore_lb)

    def forward(self, logits, labels):
        #获取p_t
        scores = F.softmax(logits, dim=1)
        #focal loss中的(1-p_t)^\gamma
        factor = torch.pow(1.-scores, self.gamma)
        #log_softmax 函数首先对输入进行 softmax 操作,然后再取对数。
        #这种方式在数值计算上更加稳定,能够避免计算指数函数时出现的数值溢出问题。比单纯的log更好。
        log_score = F.log_softmax(logits, dim=1)
        #在 Focal Loss 中,负号通常是作为损失函数的一部分而被省略了。
        #因为在实际使用中,我们更关心最小化损失值,而不是其符号。
        log_score = factor * log_score
        if self.soft_loss:
            #自定义软负对数似然损失,通过处理目标值和进行加权融合,以及对有效样本进行计数和相应的加权求和,来得到模型预测的损失值。
            loss = soft_nll(log_score, labels, ignore_index = self.ignore_lb)
        else:
            #负对数似然损失通常用于多分类问题中,特别是在神经网络输出的概率分布上。
            #通过设置 ignore_index=ignore_lb,这意味着在计算损失时会忽略标签值为 ignore_lb 的样本,不将其纳入损失的计算范围内。
            #这个功能在处理带有遮蔽值或者填充值的数据时非常有用。
            #这样做可以更好地处理不均衡的数据集或者特殊标记的样本,同时保证模型的训练稳定性和准确性。
            loss = self.nll(log_score, labels)

        # import pdb; pdb.set_trace()
        return loss

你可能感兴趣的:(人工智能,机器学习,深度学习)