【Pytorch】 Dice系数与Dice Loss损失函数实现

由于 Dice系数是图像分割中常用的指标,而在Pytoch中没有官方的实现,下面通过自己的想法并结合网上的一些参考进行详细实现。

先来看一个我在网上看到的一个版本。

def diceCoeff(pred, gt, smooth=1, activation='sigmoid'):
    r""" computational formula:
        dice = (2 * (pred ∩ gt)) / (pred ∪ gt)
    """

    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    pred_flat = pred.view(N, -1)
    gt_flat = gt.view(N, -1)

    intersection = (pred_flat * gt_flat).sum(1)
    unionset = pred_flat.sum(1) + gt_flat.sum(1)
    loss = 2 * (intersection + smooth) / (unionset + smooth)

    return loss.sum() / N

整体思路就是运用dice的计算公式 (2 * A∩B) / (A∪B)。下面来分析一下可能存在的问题:

smooth参数是用来防止分母除0的,但是如果smooth=1的话,会使得dice的计算结果略微偏高,看下面的测试代码。

第一种情况:预测和标签完全一样

# shape = torch.Size([1, 3, 4, 4])
'''
1 0 0= bladder
0 1 0 = tumor
0 0 1= background 
'''
pred = torch.Tensor([[
        [[0, 1, 1, 0],
         [1, 0, 0, 1],
         [1, 0, 0, 1],
         [0, 1, 1, 0]],
        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[1, 0, 0, 1],
         [0, 1, 1, 0],
         [0, 1, 1, 0],
         [1, 0, 0, 1]]]])
    
gt = torch.Tensor([[
        [[0, 1, 1, 0],
         [1, 0, 0, 1],
         [1, 0, 0, 1],
         [0, 1, 1, 0]],
        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[1, 0, 0, 1],
         [0, 1, 1, 0],
         [0, 1, 1, 0],
         [1, 0, 0, 1]]]])


dice_baldder1 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1, activation=None)
dice_baldder2 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1e-5, activation=None)
print('smooth=1 : dice={:.4}'.format(dice_baldder1.item()))
print('smooth=1e-5 : dice={:.4}'.format(dice_baldder2.item()))

# 输出结果
smooth=1 : dice=1.050
smooth=1e-5 : dice=1.0

我们最后预测的是一个3分类的分割图,第一类是baldder, 第二类是tumor, 第三类是背景。我们先假设bladder的预测pred和gt一样,计算bladder的dice值,发现当smooth=1的时候,dice偏高, 而smooth=1e-5时dice比较合理。

 

解决办法:我想这里应该更改代码的实现方式,用下面的计算公式替换之前的,因为之前加smooth的位置有问题。

# loss = 2 * (intersection + smooth) / (unionset + smooth)  # 之前的

loss = (2 * intersection + smooth) / (unionset + smooth)

替换后的dice如下:

def diceCoeff(pred, gt, smooth=1e-5, activation='sigmoid'):
    r""" computational formula:
        dice = (2 * (pred ∩ gt)) / (pred ∪ gt)
    """

    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    pred_flat = pred.view(N, -1)
    gt_flat = gt.view(N, -1)

    intersection = (pred_flat * gt_flat).sum(1)
    unionset = pred_flat.sum(1) + gt_flat.sum(1)
    loss = (2 * intersection + smooth) / (unionset + smooth)

    return loss.sum() / N

上面用到的测试数据进行验证结果如下:dice计算正确 

# smooth=1 : dice=1.0
# smooth=1e-5 : dice=1.0

第二种情况:预测的结果不在标签中 

如下面的代码,我们假设预测的pred中有一部分bladder,但gt中没有bladder,看计算出的dice值如何。

'''
    1 0 0= bladder
    0 1 0 = tumor
    0 0 1= background 
    '''
    pred = torch.Tensor([[
        [[0, 1, 1, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[1, 0, 0, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]]])

    gt = torch.Tensor([[
        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],
        [[1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]]])

    dice_baldder1 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1, activation=None)
    dice_baldder2 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1e-5, activation=None)
    print('smooth=1 : dice={:.4}'.format(dice_baldder1.item()))
    print('smooth=1e-5 : dice={:.4}'.format(dice_baldder2.item()))

# 输出结果
smooth=1 : dice=0.3333
smooth=1e-5 : dice=5e-06

从结果可以看到,smooth=1时的dice值为0.3333;而 smooth=1e-5时的dice值接近于0,较为合理。

dice的另一种计算方式:这里参考肾脏肿瘤挑战赛提供的dice计算方法。

【Pytorch】 Dice系数与Dice Loss损失函数实现_第1张图片

 

def diceCoeffv2(pred, gt, eps=1e-5, activation='sigmoid'):
    r""" computational formula:
        dice = (2 * tp) / (2 * tp + fp + fn)
    """

    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    pred_flat = pred.view(N, -1)
    gt_flat = gt.view(N, -1)

    tp = torch.sum(gt_flat * pred_flat, dim=1)
    fp = torch.sum(pred_flat, dim=1) - tp
    fn = torch.sum(gt_flat, dim=1) - tp
    loss = (2 * tp + eps) / (2 * tp + fp + fn + eps)
    return loss.sum() / N

 

整理代码

def diceCoeff(pred, gt, smooth=1e-5, activation='sigmoid'):
    r""" computational formula:
        dice = (2 * (pred ∩ gt)) / (pred ∪ gt)
    """

    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    pred_flat = pred.view(N, -1)
    gt_flat = gt.view(N, -1)

    intersection = (pred_flat * gt_flat).sum(1)
    unionset = pred_flat.sum(1) + gt_flat.sum(1)
    loss = (2 * intersection + smooth) / (unionset + smooth)

    return loss.sum() / N



def diceCoeffv2(pred, gt, eps=1e-5, activation='sigmoid'):
    r""" computational formula:
        dice = (2 * tp) / (2 * tp + fp + fn)
    """

    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    pred_flat = pred.view(N, -1)
    gt_flat = gt.view(N, -1)

    tp = torch.sum(gt_flat * pred_flat, dim=1)
    fp = torch.sum(pred_flat, dim=1) - tp
    fn = torch.sum(gt_flat, dim=1) - tp
    loss = (2 * tp + eps) / (2 * tp + fp + fn + eps)
    return loss.sum() / N


# v2的另一种代码写法
def diceCoeffv3(pred, gt, eps=1e-5, activation='sigmoid'):
    r""" computational formula:
        dice = (2 * tp) / (2 * tp + fp + fn)
    """

    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d")

    pred = activation_fn(pred)

    N = gt.size(0)
    pred_flat = pred.view(N, -1)
    gt_flat = gt.view(N, -1)

    tp = torch.sum((pred_flat != 0) * (gt_flat != 0), dim=1)
    fp = torch.sum((pred_flat != 0) * (gt_flat == 0), dim=1)
    fn = torch.sum((pred_flat == 0) * (gt_flat != 0), dim=1)
    # 转为float,以防long类型之间相除结果为0
    loss = (2 * tp + eps).float() / (2 * tp + fp + fn + eps).float()

    return loss.sum() / N

 

class SoftDiceLoss(nn.Module):
    __name__ = 'dice_loss'

    def __init__(self, activation='sigmoid'):
        super(SoftDiceLoss, self).__init__()
        self.activation = activation

    def forward(self, y_pr, y_gt):
        return 1 - diceCoeffv2(y_pr, y_gt, activation=self.activation)

代码测试:

if __name__ == '__main__':
    
    # shape = torch.Size([2, 3, 4, 4])
    # 模拟batch_size = 2
    '''
    1 0 0= bladder
    0 1 0 = tumor
    0 0 1= background 
    '''
    pred = torch.Tensor([[
        [[0, 1, 0, 0],
         [1, 0, 0, 1],
         [1, 0, 0, 1],
         [0, 1, 1, 0]],
        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 1, 1, 0],
         [0, 0, 0, 0]],
        [[1, 0, 1, 1],
         [0, 1, 1, 0],
         [0, 0, 0, 0],
         [1, 0, 0, 1]]],
        [
            [[0, 1, 0, 0],
             [1, 0, 0, 1],
             [1, 0, 0, 1],
             [0, 1, 1, 0]],
            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 1, 1, 0],
             [0, 0, 0, 0]],
            [[1, 0, 1, 1],
             [0, 1, 1, 0],
             [0, 0, 0, 0],
             [1, 0, 0, 1]]]
    ])

    gt = torch.Tensor([[
        [[0, 1, 1, 0],
         [1, 0, 0, 1],
         [1, 0, 0, 1],
         [0, 1, 1, 0]],
        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 1, 1, 0],
         [0, 0, 0, 0]],
        [[1, 0, 0, 1],
         [0, 1, 1, 0],
         [0, 0, 0, 0],
         [1, 0, 0, 1]]],
        [
            [[0, 1, 1, 0],
             [1, 0, 0, 1],
             [1, 0, 0, 1],
             [0, 1, 1, 0]],
            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 1, 1, 0],
             [0, 0, 0, 0]],
            [[1, 0, 0, 1],
             [0, 1, 1, 0],
             [0, 0, 0, 0],
             [1, 0, 0, 1]]]
    ])


    dice1 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], activation=None)
    dice2 = diceCoeffv2(pred[:, 0:1, :], gt[:, 0:1, :], activation=None)
    dice3 = diceCoeffv3(pred[:, 0:1, :], gt[:, 0:1, :], activation=None)
    print(dice1, dice2, dice3)

# 输出
tensor(0.9333) tensor(0.9333) tensor(0.9333)

 

总结:上面是这几天对dice以及dice loss的一些思考和实现。

2020/6/8更新:SoftDiceLoss的计算方式修改。在实际项目种训练发现之前的loss计算方式不够准确,现在按类别计算dice,求平均之后得到loss。

class SoftDiceLossV2(_Loss):
    __name__ = 'dice_loss'

    def __init__(self, num_classes, activation='sigmoid', reduction='mean'):
        super(SoftDiceLossV2, self).__init__()
        self.activation = activation
        self.num_classes = num_classes

    def forward(self, y_pred, y_true):
        class_dice = []
        for i in range(1, self.num_classes):
            class_dice.append(diceCoeff(y_pred[:, i:i + 1, :], y_true[:, i:i + 1, :], activation=self.activation))
        mean_dice = sum(class_dice) / len(class_dice)
        return 1 - mean_dice

 

你可能感兴趣的:(图像分割)