Dice Loss

导读

​ Dice Loss是由 Dice 系数而得名的,Dice系数是一种用于评估两个样本相似性的度量函数,其值越大意味着这两个样本越相似,Dice系数的数学表达式如下:
 Dice  = 2 ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ \text { Dice }=\frac{2|X \cap Y|}{|X|+|Y|}  Dice =X+Y2∣XY

其中, ∣ X ∩ Y ∣ |X \cap Y| XY 表示 X 和 Y 之间交集元素的个数, ∣ X ∣ |X| X ∣ Y ∣ |Y| Y 分别表示 X,Y 中元素的个数。Dice Loss 表达式如下:

 DiceLoss  = 1 −  Dice  = 1 − 2 ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ \text { DiceLoss }=1-\text { Dice }=1-\frac{2|X \cap Y|}{|X|+|Y|}  DiceLoss =1 Dice =1X+Y2∣XY

语义分割下的dice loss

​ Dice Loss常用于语义分割问题中,计算公式不变,但是变量含义有所改变。 X \mathrm{X} X 表示真实分割图像的像素标签, Y \mathrm{Y} Y 表示模型预测分割图像的像素类别。 Y \mathrm{Y} Y 只有 0,1 两个值,0表示像素点不是目标类,1表示像素点是目标类。 ∣ X ∩ Y ∣ |X \cap Y| XY 近似为预测图像的像素与真实标签图像的像素之间的点乘,并将点乘结果相加, ∣ X ∣ |X| X ∣ Y ∣ |Y| Y​ 分别近似为它们各自对应图像中的像素相加。故有公式:
 DiceLoss  = 1 − 2 ∑ i = 1 N y i y ^ i ∑ i = 1 N y i + ∑ i = 1 N y i ^ \text { DiceLoss }=1-\frac{2 \sum_{i=1}^N y_i \hat{y}_i}{\sum_{i=1}^N y_i+\sum_{i=1}^N \hat{y_i}}  DiceLoss =1i=1Nyi+i=1Nyi^2i=1Nyiy^i

注意,dice loss通常是不计算背景类的。

​ 对于多分类问题,对 label 进行 one hot 编码,生成多个 label 图,每个类别对应一个二分类label图。通过计算每个类别的 Dice Loss 损失,最后再求均值即得到多分类的 Dice Loss 损失。

等价F1-score

​ 假设有两个集合 A A A B B B , Dice系数定义为:
Dice ⁡ ( A , B ) = 2 ∣ A ∩ B ∣ ∣ A ∣ + ∣ B ∣ \operatorname{Dice}(A, B) =\frac{2|A \cap B|}{|A|+|B|} Dice(A,B)=A+B2∣AB
A ∩ B A \cap B AB 表示预测结果与真实标签的交集,在二分类问题中等于正确预测为正类的数量 TP 。而 F P FP FP 表示预测为正类但实际上是负类的数量 (属于A,但不属于B) , F N FN FN 表示预测为负类但实际上是正类的数量 (属于B,但不属于A) ,故又有
2 ∗ T P + F P + F N = ∣ A ∣ + ∣ B ∣ 2* \mathrm{TP}+\mathrm{FP}+\mathrm{FN} = |A|+|B| 2TP+FP+FN=A+B
故有,
D i c e = 2 ∗ T P 2 ∗ T P + F P + F N Dice = \frac{2 * TP}{2 * TP + FP + FN} Dice=2TP+FP+FN2TP
,该公式正好等于 F1-score。

二分类例子

​ 假设模型输出的预测值如下

Dice Loss_第1张图片标签 label 如下(0 即对应背景,表示不属于某一类,1 表示属于某一类):
Dice Loss_第2张图片计算类别1的dice:
∣ X ∩ Y ∣ = [ 0.5322 0.4932 0.1764 0.3107 0.5297 0.1604 0.3841 0.3537 0.3574 0.3323 0.8301 0.6436 ] ⋆ [ 0 0 0 0 0 0 1 1 1 1 1 1 ] = [ 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.3841 0.3537 0.3574 0.3323 0.8301 0.6436 ] → 2.9012  (求和)  \begin{aligned} |\mathrm{X} \cap \mathrm{Y}| & =\left[\begin{array}{lll} 0.5322 & 0.4932 & 0.1764 \\ 0.3107 & 0.5297 & 0.1604 \\ 0.3841 & 0.3537 & 0.3574 \\ 0.3323 & 0.8301 & 0.6436 \end{array}\right] \star\left[\begin{array}{lll} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{array}\right] \\ \\ & =\left[\begin{array}{lll} 0.0000 & 0.0000 & 0.0000 \\ 0.0000 & 0.0000 & 0.0000 \\ 0.3841 & 0.3537 & 0.3574 \\ 0.3323 & 0.8301 & 0.6436 \end{array}\right] \rightarrow 2.9012 \text { (求和) } \end{aligned} XY= 0.53220.31070.38410.33230.49320.52970.35370.83010.17640.16040.35740.6436 001100110011 = 0.00000.00000.38410.33230.00000.00000.35370.83010.00000.00000.35740.6436 2.9012 (求和

∣ X ∣ = [ 0.5322 0.4932 0.1764 0.3107 0.5297 0.1604 0.3841 0.3537 0.3574 0.3323 0.8301 0.6436 ] → 5.1038 |\mathrm{X}|=\left[\begin{array}{lll} 0.5322 & 0.4932 & 0.1764 \\ 0.3107 & 0.5297 & 0.1604 \\ 0.3841 & 0.3537 & 0.3574 \\ 0.3323 & 0.8301 & 0.6436 \end{array}\right] \rightarrow 5.1038 X= 0.53220.31070.38410.33230.49320.52970.35370.83010.17640.16040.35740.6436 5.1038

∣ Y ∣ = [ 0 0 0 0 0 0 1 1 1 1 1 1 ] → 6 |Y|=\left[\begin{array}{lll} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{array}\right] \rightarrow 6 Y= 001100110011 6

所以 Dice 系数为
D = 2 ∗ ∣ X ∩ Y ∣ + 1 ∣ X ∣ + ∣ Y ∣ + 1 = 2 ∗ 2.9012 + 1 5.1038 + 6 + 1 = 0.5620 \mathrm{D}=\frac{2 *|\mathrm{X} \cap \mathrm{Y}|+1}{|\mathrm{X}|+|\mathrm{Y}|+1}=\frac{2 * 2.9012+1}{5.1038+6+1}=0.5620 D=X+Y+12XY+1=5.1038+6+122.9012+1=0.5620
所以 Dice 损失 L = 1 − D = 0.4380 \mathrm{L}=1-\mathrm{D}=0.4380 L=1D=0.4380

优点

​ Dice Loss 可以缓解样本中前景背景(面积)不平衡带来的消极影响,前景背景不平衡也就是说图像中大部分区域是不包含目标的,只有一小部分区域包含目标。Dice Loss训练更关注对前景区域的挖掘,即保证有较低的FN,但会存在损失饱和问题,而CE Loss是平等地计算每个像素点的损失,当前点的损失只和当前预测值与真实标签值的距离有关,这会导致一些问题(见Focal Loss)。因此单独使用Dice Loss往往并不能取得较好的结果,需要进行组合使用,比如Dice Loss+CE Loss或者Dice Loss+Focal Loss等。

代码

def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
        
    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
    temp_target = target.view(n, -1, ct)
 
    #--------------------------------------------#
    #   计算dice loss
    #--------------------------------------------#
    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp
    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp
 
    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    dice_loss = 1 - torch.mean(score)
    return dice_loss

你可能感兴趣的:(python,cv)