由于 Dice系数是图像分割中常用的指标,而在Pytoch中没有官方的实现,下面通过自己的想法并结合网上的一些参考进行详细实现。
先来看一个我在网上看到的一个版本。
def diceCoeff(pred, gt, smooth=1, activation='sigmoid'):
r""" computational formula:
dice = (2 * (pred ∩ gt)) / (pred ∪ gt)
"""
if activation is None or activation == "none":
activation_fn = lambda x: x
elif activation == "sigmoid":
activation_fn = nn.Sigmoid()
elif activation == "softmax2d":
activation_fn = nn.Softmax2d()
else:
raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")
pred = activation_fn(pred)
N = gt.size(0)
pred_flat = pred.view(N, -1)
gt_flat = gt.view(N, -1)
intersection = (pred_flat * gt_flat).sum(1)
unionset = pred_flat.sum(1) + gt_flat.sum(1)
loss = 2 * (intersection + smooth) / (unionset + smooth)
return loss.sum() / N
整体思路就是运用dice的计算公式 (2 * A∩B) / (A∪B)。下面来分析一下可能存在的问题:
smooth参数是用来防止分母除0的,但是如果smooth=1的话,会使得dice的计算结果略微偏高,看下面的测试代码。
第一种情况:预测和标签完全一样
# shape = torch.Size([1, 3, 4, 4])
'''
1 0 0= bladder
0 1 0 = tumor
0 0 1= background
'''
pred = torch.Tensor([[
[[0, 1, 1, 0],
[1, 0, 0, 1],
[1, 0, 0, 1],
[0, 1, 1, 0]],
[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[1, 0, 0, 1],
[0, 1, 1, 0],
[0, 1, 1, 0],
[1, 0, 0, 1]]]])
gt = torch.Tensor([[
[[0, 1, 1, 0],
[1, 0, 0, 1],
[1, 0, 0, 1],
[0, 1, 1, 0]],
[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[1, 0, 0, 1],
[0, 1, 1, 0],
[0, 1, 1, 0],
[1, 0, 0, 1]]]])
dice_baldder1 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1, activation=None)
dice_baldder2 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1e-5, activation=None)
print('smooth=1 : dice={:.4}'.format(dice_baldder1.item()))
print('smooth=1e-5 : dice={:.4}'.format(dice_baldder2.item()))
# 输出结果
smooth=1 : dice=1.050
smooth=1e-5 : dice=1.0
我们最后预测的是一个3分类的分割图,第一类是baldder, 第二类是tumor, 第三类是背景。我们先假设bladder的预测pred和gt一样,计算bladder的dice值,发现当smooth=1的时候,dice偏高, 而smooth=1e-5时dice比较合理。
解决办法:我想这里应该更改代码的实现方式,用下面的计算公式替换之前的,因为之前加smooth的位置有问题。
# loss = 2 * (intersection + smooth) / (unionset + smooth) # 之前的
loss = (2 * intersection + smooth) / (unionset + smooth)
替换后的dice如下:
def diceCoeff(pred, gt, smooth=1e-5, activation='sigmoid'):
r""" computational formula:
dice = (2 * (pred ∩ gt)) / (pred ∪ gt)
"""
if activation is None or activation == "none":
activation_fn = lambda x: x
elif activation == "sigmoid":
activation_fn = nn.Sigmoid()
elif activation == "softmax2d":
activation_fn = nn.Softmax2d()
else:
raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")
pred = activation_fn(pred)
N = gt.size(0)
pred_flat = pred.view(N, -1)
gt_flat = gt.view(N, -1)
intersection = (pred_flat * gt_flat).sum(1)
unionset = pred_flat.sum(1) + gt_flat.sum(1)
loss = (2 * intersection + smooth) / (unionset + smooth)
return loss.sum() / N
上面用到的测试数据进行验证结果如下:dice计算正确
# smooth=1 : dice=1.0
# smooth=1e-5 : dice=1.0
第二种情况:预测的结果不在标签中
如下面的代码,我们假设预测的pred中有一部分bladder,但gt中没有bladder,看计算出的dice值如何。
'''
1 0 0= bladder
0 1 0 = tumor
0 0 1= background
'''
pred = torch.Tensor([[
[[0, 1, 1, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[1, 0, 0, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]]]])
gt = torch.Tensor([[
[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]]]])
dice_baldder1 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1, activation=None)
dice_baldder2 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1e-5, activation=None)
print('smooth=1 : dice={:.4}'.format(dice_baldder1.item()))
print('smooth=1e-5 : dice={:.4}'.format(dice_baldder2.item()))
# 输出结果
smooth=1 : dice=0.3333
smooth=1e-5 : dice=5e-06
从结果可以看到,smooth=1时的dice值为0.3333;而 smooth=1e-5时的dice值接近于0,较为合理。
dice的另一种计算方式:这里参考肾脏肿瘤挑战赛提供的dice计算方法。
def diceCoeffv2(pred, gt, eps=1e-5, activation='sigmoid'):
r""" computational formula:
dice = (2 * tp) / (2 * tp + fp + fn)
"""
if activation is None or activation == "none":
activation_fn = lambda x: x
elif activation == "sigmoid":
activation_fn = nn.Sigmoid()
elif activation == "softmax2d":
activation_fn = nn.Softmax2d()
else:
raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")
pred = activation_fn(pred)
N = gt.size(0)
pred_flat = pred.view(N, -1)
gt_flat = gt.view(N, -1)
tp = torch.sum(gt_flat * pred_flat, dim=1)
fp = torch.sum(pred_flat, dim=1) - tp
fn = torch.sum(gt_flat, dim=1) - tp
loss = (2 * tp + eps) / (2 * tp + fp + fn + eps)
return loss.sum() / N
整理代码
def diceCoeff(pred, gt, smooth=1e-5, activation='sigmoid'):
r""" computational formula:
dice = (2 * (pred ∩ gt)) / (pred ∪ gt)
"""
if activation is None or activation == "none":
activation_fn = lambda x: x
elif activation == "sigmoid":
activation_fn = nn.Sigmoid()
elif activation == "softmax2d":
activation_fn = nn.Softmax2d()
else:
raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")
pred = activation_fn(pred)
N = gt.size(0)
pred_flat = pred.view(N, -1)
gt_flat = gt.view(N, -1)
intersection = (pred_flat * gt_flat).sum(1)
unionset = pred_flat.sum(1) + gt_flat.sum(1)
loss = (2 * intersection + smooth) / (unionset + smooth)
return loss.sum() / N
def diceCoeffv2(pred, gt, eps=1e-5, activation='sigmoid'):
r""" computational formula:
dice = (2 * tp) / (2 * tp + fp + fn)
"""
if activation is None or activation == "none":
activation_fn = lambda x: x
elif activation == "sigmoid":
activation_fn = nn.Sigmoid()
elif activation == "softmax2d":
activation_fn = nn.Softmax2d()
else:
raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")
pred = activation_fn(pred)
N = gt.size(0)
pred_flat = pred.view(N, -1)
gt_flat = gt.view(N, -1)
tp = torch.sum(gt_flat * pred_flat, dim=1)
fp = torch.sum(pred_flat, dim=1) - tp
fn = torch.sum(gt_flat, dim=1) - tp
loss = (2 * tp + eps) / (2 * tp + fp + fn + eps)
return loss.sum() / N
# v2的另一种代码写法
def diceCoeffv3(pred, gt, eps=1e-5, activation='sigmoid'):
r""" computational formula:
dice = (2 * tp) / (2 * tp + fp + fn)
"""
if activation is None or activation == "none":
activation_fn = lambda x: x
elif activation == "sigmoid":
activation_fn = nn.Sigmoid()
elif activation == "softmax2d":
activation_fn = nn.Softmax2d()
else:
raise NotImplementedError("Activation implemented for sigmoid and softmax2d")
pred = activation_fn(pred)
N = gt.size(0)
pred_flat = pred.view(N, -1)
gt_flat = gt.view(N, -1)
tp = torch.sum((pred_flat != 0) * (gt_flat != 0), dim=1)
fp = torch.sum((pred_flat != 0) * (gt_flat == 0), dim=1)
fn = torch.sum((pred_flat == 0) * (gt_flat != 0), dim=1)
# 转为float,以防long类型之间相除结果为0
loss = (2 * tp + eps).float() / (2 * tp + fp + fn + eps).float()
return loss.sum() / N
class SoftDiceLoss(nn.Module):
__name__ = 'dice_loss'
def __init__(self, activation='sigmoid'):
super(SoftDiceLoss, self).__init__()
self.activation = activation
def forward(self, y_pr, y_gt):
return 1 - diceCoeffv2(y_pr, y_gt, activation=self.activation)
代码测试:
if __name__ == '__main__':
# shape = torch.Size([2, 3, 4, 4])
# 模拟batch_size = 2
'''
1 0 0= bladder
0 1 0 = tumor
0 0 1= background
'''
pred = torch.Tensor([[
[[0, 1, 0, 0],
[1, 0, 0, 1],
[1, 0, 0, 1],
[0, 1, 1, 0]],
[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 0, 0, 0]],
[[1, 0, 1, 1],
[0, 1, 1, 0],
[0, 0, 0, 0],
[1, 0, 0, 1]]],
[
[[0, 1, 0, 0],
[1, 0, 0, 1],
[1, 0, 0, 1],
[0, 1, 1, 0]],
[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 0, 0, 0]],
[[1, 0, 1, 1],
[0, 1, 1, 0],
[0, 0, 0, 0],
[1, 0, 0, 1]]]
])
gt = torch.Tensor([[
[[0, 1, 1, 0],
[1, 0, 0, 1],
[1, 0, 0, 1],
[0, 1, 1, 0]],
[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 0, 0, 0]],
[[1, 0, 0, 1],
[0, 1, 1, 0],
[0, 0, 0, 0],
[1, 0, 0, 1]]],
[
[[0, 1, 1, 0],
[1, 0, 0, 1],
[1, 0, 0, 1],
[0, 1, 1, 0]],
[[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 1, 1, 0],
[0, 0, 0, 0]],
[[1, 0, 0, 1],
[0, 1, 1, 0],
[0, 0, 0, 0],
[1, 0, 0, 1]]]
])
dice1 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], activation=None)
dice2 = diceCoeffv2(pred[:, 0:1, :], gt[:, 0:1, :], activation=None)
dice3 = diceCoeffv3(pred[:, 0:1, :], gt[:, 0:1, :], activation=None)
print(dice1, dice2, dice3)
# 输出
tensor(0.9333) tensor(0.9333) tensor(0.9333)
总结:上面是这几天对dice以及dice loss的一些思考和实现。
2020/6/8更新:SoftDiceLoss的计算方式修改。在实际项目种训练发现之前的loss计算方式不够准确,现在按类别计算dice,求平均之后得到loss。
class SoftDiceLossV2(_Loss):
__name__ = 'dice_loss'
def __init__(self, num_classes, activation='sigmoid', reduction='mean'):
super(SoftDiceLossV2, self).__init__()
self.activation = activation
self.num_classes = num_classes
def forward(self, y_pred, y_true):
class_dice = []
for i in range(1, self.num_classes):
class_dice.append(diceCoeff(y_pred[:, i:i + 1, :], y_true[:, i:i + 1, :], activation=self.activation))
mean_dice = sum(class_dice) / len(class_dice)
return 1 - mean_dice