图像分割中涉及的损失函数(主要用来处理样本不平衡)

图像分割中的损失函数

  • 前言
  • 解决办法
  • 损失函数
    • 1. log loss损失函数
    • 2. WBE loss
    • 3. Focal loss
      • 应用场景
      • 思想
      • 公式
    • Dice loss
      • Dice 系数
      • Dice 差异函数

前言

图像分割中的loss函数继承了深度学习模型中一般损失函数的所有特点,但是又有其自身的特点,即需要处理类别不平衡的问题,在进行图像分割中,往往图像中成为背景的像素值占绝大多数,而前景目标的像素只有很少一部分。
注:以下链接详细介绍了深度学习模型中的一般损失函数。

参考链接:https://blog.csdn.net/weixin_38410551/article/details/104973011

如图所示:
图像分割中涉及的损失函数(主要用来处理样本不平衡)_第1张图片

图像分割中涉及的损失函数(主要用来处理样本不平衡)_第2张图片
注:车道线只是占很少一部分,大部分为背景。

图像分割处理的时候,经常要遇到这样的样本不均衡问题

解决办法

主流的解决办法,都是通过减少样本中样本数较多的类别损失函数权重,增加样本中样本数较少的类别损失函数的权重。这样,预测样本较少的类别,损失函数下降的更快,而预测样本较多的类别,损失函数下降得慢。
注:防止过拟合,也采取了类似的办法,通过增加权重参数的二次项,在优化损失函数的过程中,来降低权重的大小,从而达到防止过拟合的目的。(过拟合:某些权重参数过大

损失函数

1. log loss损失函数

参考链接:https://blog.csdn.net/weixin_38410551/article/details/104973011

2. WBE loss

思想:样本数目较多的类别,要减小权重,样本数目较少的类别,要增加权重。
公式: W C E = − 1 N ∑ n = 1 N w r n ​ l o g ( p n ) + ( 1 − r n ​ ) l o g ( 1 − p n ​ ) WCE=−\frac{1}{N}\sum_{n = 1}^{N}wr_{n}​log(p_{n})+(1−r_{n}​)log(1−p_{n}​) WCE=N1n=1Nwrnlog(pn)+(1rn)log(1pn)
w = − N − ∑ n p n ∑ n p n w = -\frac{N - \sum_{n}p_{n}}{\sum_{n}p_{n}} w=npnNnpn
其中, w w w为权重。
缺点:需要人为去调整权重。

3. Focal loss

应用场景

应用于目标检测的二分类问题。

思想

与WB loss思想相同,但是这里的参数,可以自动化调节。

公式

− 1 N ∑ i = 1 N ( α y i ( 1 − p i ) γ ​ l o g ( p i ) + ( 1 − α ) ( 1 − y i ) p i γ l o g ( 1 − p i ​ ) ) −\frac{1}{N}\sum_{i = 1}^{N}(\alpha y_{i}(1 - p_{i})^{\gamma}​log(p_{i})+(1−\alpha)(1-y_{i})p_{i}^{\gamma}log(1−p_{i}​)) N1i=1N(αyi(1pi)γlog(pi)+(1α)(1yi)piγlog(1pi))
其基本思想就是,对于类别极度不均衡的情况下,网络如果在log loss下会倾向于只预测负样本,并且负样本的预测概率pipi p_ipi​也会非常的高,回传的梯度也很大。但是如果添加(1−pi)γ(1−pi)γ (1-p_i)^{\gamma}(1−pi​)γ则会使预测概率大的样本得到的loss变小,而预测概率小的样本,loss变得大,从而加强对正样本的关注度。
可以改善目标不均衡的现象,对此情况比 binary_crossentropy 要好很多。

Dice loss

Dice 系数

2 A ⋂ B A ⋃ B \frac{2A\bigcap B}{A\bigcup B} AB2AB
其中,A为样本标签值,B为预测值,是一种集合相似度度量函数,通常用来衡量两个集合的相似度。(取值范围[0,1])。
分子乘以2, 是因为分母存在重复计算A和B的重合元素。

Dice 差异函数

很简单,1减去Dice系数,就是差异函数。Dice系数和Dice差异函数是对同一问题的两种表述方式。 1 − 2 X ⋂ Y X ⋃ Y 1 - \frac{2X\bigcap Y}{X\bigcup Y} 1XY2XY

注:dice loss 比较适用于样本极度不均的情况,一般的情况下,使用 dice loss 会对反向传播造成不利的影响,容易使训练变得不稳定.
示例代码:

class BinaryDiceLoss(nn.Module):
    """Dice loss of binary class
    Args:
        smooth: A float number to smooth loss, and avoid NaN error, default: 1
        p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
        predict: A tensor of shape [N, *]
        target: A tensor of shape same with predict
        reduction: Reduction method to apply, return mean over batch if 'mean',
            return sum if 'sum', return a tensor of shape [N,] if 'none'
    Returns:
        Loss tensor according to arg reduction
    Raise:
        Exception if unexpected reduction
    """
    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth #避免分子为0
        self.p = p #平方值
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1)
        #num:分子
        num = 2*torch.sum(torch.mul(predict, target), dim=1) + self.smooth
        #den 分母
        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth
        #dice loss
        loss = 1 - num / den

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise Exception('Unexpected reduction {}'.format(self.reduction))


class DiceLoss(nn.Module):
    """Dice loss, need one hot encode input
    Args:
        weight: An array of shape [num_classes,]
        ignore_index: class index to ignore
        predict: A tensor of shape [N, C, *]
        target: A tensor of same shape with predict
        other args pass to BinaryDiceLoss
    Return:
        same as BinaryDiceLoss
    """
    def __init__(self, weight=None, ignore_index=None, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        #predict和target两者形状相同,否则报错
        assert predict.shape == target.shape, 'predict & target shape do not match'
        #对BinaryDiceLoss进行实例化
        dice = BinaryDiceLoss(**self.kwargs)
        total_loss = 0
        predict = F.softmax(predict, dim=1)

        for i in range(target.shape[1]):
            if i != self.ignore_index:
                dice_loss = dice(predict[:, i], target[:, i])
                if self.weight is not None:
                    assert self.weight.shape[0] == target.shape[1], \
                        'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0])
                    dice_loss *= self.weights[i]
                total_loss += dice_loss

        return total_loss/target.shape[1]   #求取了一个均值

参考链接:https://blog.csdn.net/m0_37477175/article/details/83004746
https://blog.csdn.net/JMU_Ma/article/details/97533768?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task
https://blog.csdn.net/m0_37477175/article/details/83004746#Dice_loss_70

你可能感兴趣的:(pytorch,深度学习)