Dice Loss是由 Dice 系数而得名的,Dice系数是一种用于评估两个样本相似性的度量函数,其值越大意味着这两个样本越相似,Dice系数的数学表达式如下:
Dice = 2 ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ \text { Dice }=\frac{2|X \cap Y|}{|X|+|Y|} Dice =∣X∣+∣Y∣2∣X∩Y∣
其中, ∣ X ∩ Y ∣ |X \cap Y| ∣X∩Y∣ 表示 X 和 Y 之间交集元素的个数, ∣ X ∣ |X| ∣X∣ 和 ∣ Y ∣ |Y| ∣Y∣ 分别表示 X,Y 中元素的个数。Dice Loss 表达式如下:
DiceLoss = 1 − Dice = 1 − 2 ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ \text { DiceLoss }=1-\text { Dice }=1-\frac{2|X \cap Y|}{|X|+|Y|} DiceLoss =1− Dice =1−∣X∣+∣Y∣2∣X∩Y∣
Dice Loss常用于语义分割问题中,计算公式不变,但是变量含义有所改变。 X \mathrm{X} X 表示真实分割图像的像素标签, Y \mathrm{Y} Y 表示模型预测分割图像的像素类别。 Y \mathrm{Y} Y 只有 0,1 两个值,0表示像素点不是目标类,1表示像素点是目标类。 ∣ X ∩ Y ∣ |X \cap Y| ∣X∩Y∣ 近似为预测图像的像素与真实标签图像的像素之间的点乘,并将点乘结果相加, ∣ X ∣ |X| ∣X∣ 和 ∣ Y ∣ |Y| ∣Y∣ 分别近似为它们各自对应图像中的像素相加。故有公式:
DiceLoss = 1 − 2 ∑ i = 1 N y i y ^ i ∑ i = 1 N y i + ∑ i = 1 N y i ^ \text { DiceLoss }=1-\frac{2 \sum_{i=1}^N y_i \hat{y}_i}{\sum_{i=1}^N y_i+\sum_{i=1}^N \hat{y_i}} DiceLoss =1−∑i=1Nyi+∑i=1Nyi^2∑i=1Nyiy^i
注意,dice loss通常是不计算背景类的。
对于多分类问题,对 label 进行 one hot 编码,生成多个 label 图,每个类别对应一个二分类label图。通过计算每个类别的 Dice Loss 损失,最后再求均值即得到多分类的 Dice Loss 损失。
假设有两个集合 A A A 和 B B B , Dice系数定义为:
Dice ( A , B ) = 2 ∣ A ∩ B ∣ ∣ A ∣ + ∣ B ∣ \operatorname{Dice}(A, B) =\frac{2|A \cap B|}{|A|+|B|} Dice(A,B)=∣A∣+∣B∣2∣A∩B∣
A ∩ B A \cap B A∩B 表示预测结果与真实标签的交集,在二分类问题中等于正确预测为正类的数量 TP 。而 F P FP FP 表示预测为正类但实际上是负类的数量 (属于A,但不属于B) , F N FN FN 表示预测为负类但实际上是正类的数量 (属于B,但不属于A) ,故又有
2 ∗ T P + F P + F N = ∣ A ∣ + ∣ B ∣ 2* \mathrm{TP}+\mathrm{FP}+\mathrm{FN} = |A|+|B| 2∗TP+FP+FN=∣A∣+∣B∣
故有,
D i c e = 2 ∗ T P 2 ∗ T P + F P + F N Dice = \frac{2 * TP}{2 * TP + FP + FN} Dice=2∗TP+FP+FN2∗TP
,该公式正好等于 F1-score。
假设模型输出的预测值如下
标签 label 如下(0 即对应背景,表示不属于某一类,1 表示属于某一类):
计算类别1的dice:
∣ X ∩ Y ∣ = [ 0.5322 0.4932 0.1764 0.3107 0.5297 0.1604 0.3841 0.3537 0.3574 0.3323 0.8301 0.6436 ] ⋆ [ 0 0 0 0 0 0 1 1 1 1 1 1 ] = [ 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.3841 0.3537 0.3574 0.3323 0.8301 0.6436 ] → 2.9012 (求和) \begin{aligned} |\mathrm{X} \cap \mathrm{Y}| & =\left[\begin{array}{lll} 0.5322 & 0.4932 & 0.1764 \\ 0.3107 & 0.5297 & 0.1604 \\ 0.3841 & 0.3537 & 0.3574 \\ 0.3323 & 0.8301 & 0.6436 \end{array}\right] \star\left[\begin{array}{lll} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{array}\right] \\ \\ & =\left[\begin{array}{lll} 0.0000 & 0.0000 & 0.0000 \\ 0.0000 & 0.0000 & 0.0000 \\ 0.3841 & 0.3537 & 0.3574 \\ 0.3323 & 0.8301 & 0.6436 \end{array}\right] \rightarrow 2.9012 \text { (求和) } \end{aligned} ∣X∩Y∣= 0.53220.31070.38410.33230.49320.52970.35370.83010.17640.16040.35740.6436 ⋆ 001100110011 = 0.00000.00000.38410.33230.00000.00000.35370.83010.00000.00000.35740.6436 →2.9012 (求和)
∣ X ∣ = [ 0.5322 0.4932 0.1764 0.3107 0.5297 0.1604 0.3841 0.3537 0.3574 0.3323 0.8301 0.6436 ] → 5.1038 |\mathrm{X}|=\left[\begin{array}{lll} 0.5322 & 0.4932 & 0.1764 \\ 0.3107 & 0.5297 & 0.1604 \\ 0.3841 & 0.3537 & 0.3574 \\ 0.3323 & 0.8301 & 0.6436 \end{array}\right] \rightarrow 5.1038 ∣X∣= 0.53220.31070.38410.33230.49320.52970.35370.83010.17640.16040.35740.6436 →5.1038
∣ Y ∣ = [ 0 0 0 0 0 0 1 1 1 1 1 1 ] → 6 |Y|=\left[\begin{array}{lll} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{array}\right] \rightarrow 6 ∣Y∣= 001100110011 →6
所以 Dice 系数为
D = 2 ∗ ∣ X ∩ Y ∣ + 1 ∣ X ∣ + ∣ Y ∣ + 1 = 2 ∗ 2.9012 + 1 5.1038 + 6 + 1 = 0.5620 \mathrm{D}=\frac{2 *|\mathrm{X} \cap \mathrm{Y}|+1}{|\mathrm{X}|+|\mathrm{Y}|+1}=\frac{2 * 2.9012+1}{5.1038+6+1}=0.5620 D=∣X∣+∣Y∣+12∗∣X∩Y∣+1=5.1038+6+12∗2.9012+1=0.5620
所以 Dice 损失 L = 1 − D = 0.4380 \mathrm{L}=1-\mathrm{D}=0.4380 L=1−D=0.4380 。
Dice Loss 可以缓解样本中前景背景(面积)不平衡带来的消极影响,前景背景不平衡也就是说图像中大部分区域是不包含目标的,只有一小部分区域包含目标。Dice Loss训练更关注对前景区域的挖掘,即保证有较低的FN,但会存在损失饱和问题,而CE Loss是平等地计算每个像素点的损失,当前点的损失只和当前预测值与真实标签值的距离有关,这会导致一些问题(见Focal Loss)。因此单独使用Dice Loss往往并不能取得较好的结果,需要进行组合使用,比如Dice Loss+CE Loss或者Dice Loss+Focal Loss等。
def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
n, c, h, w = inputs.size()
nt, ht, wt, ct = target.size()
if h != ht and w != wt:
inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
temp_target = target.view(n, -1, ct)
#--------------------------------------------#
# 计算dice loss
#--------------------------------------------#
tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
fp = torch.sum(temp_inputs , axis=[0,1]) - tp
fn = torch.sum(temp_target[...,:-1] , axis=[0,1]) - tp
score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
dice_loss = 1 - torch.mean(score)
return dice_loss