语义分割任务中多分类 Focal loss 与 Dice loss 的 PyTorch 实现

1. Focal loss

import torch
import torch.nn as nn
import torch.nn.functional as F

class Focal_loss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    Args:
        weight: optional 1-D tensor of per-class weights (length C). When
            given, the result is the weight-normalised sum
            ``sum(w[t] * fl) / sum(w[t])``; otherwise it is the plain mean
            over all elements.
        gamma: focusing parameter. With ``gamma=0`` the loss reduces to
            (weighted) cross-entropy, up to the probability clamp below.

    Forward shapes:
        predict: ``[N, C, *]`` unnormalised logits.
        target: ``[N, *]`` integer class indices in ``[0, C)``.

    Returns:
        A scalar loss tensor.
    """

    def __init__(self, weight=None, gamma=0):
        super(Focal_loss, self).__init__()
        self.weight = weight
        self.gamma = gamma
        self.eps = 1e-8

    def forward(self, predict, target):
        target = target.long()
        input_soft = F.softmax(predict, dim=1)
        # Probability the model assigns to the true class of each element: [N, *].
        probs = input_soft.gather(1, target.unsqueeze(1)).squeeze(1)
        # The clamp is required: without it log(probs) can be -inf and the loss NaN.
        probs = probs.clamp(min=0.001, max=0.999)
        focal_weight = (1 + self.eps - probs) ** self.gamma
        per_element = -torch.log(probs) * focal_weight

        if self.weight is not None:
            # Per-element weight = weight of that element's true class.
            weights = self.weight.to(predict.device)[target]
            return torch.sum(per_element * weights) / torch.sum(weights)
        return torch.mean(per_element)

使用Focal loss:

    # logits: [N=1, C=2, L=2]; target: [N=1, L=2] class indices.
    logits = torch.rand(1, 2, 2)
    target = torch.tensor([[0, 1]]).repeat(logits.shape[0], 1)
    FL = Focal_loss(gamma=0.5)
    print('focal loss: ', FL(logits, target))


2. Dice loss

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def make_one_hot(input, num_classes):
    """Convert a class-index tensor to a one-hot encoded tensor.

    Args:
        input: integer tensor of shape ``[N, 1, *]`` holding class indices.
        num_classes: number of classes ``C``.

    Returns:
        A float32 tensor on the CPU of shape ``[N, num_classes, *]``.
    """
    shape = list(input.shape)
    shape[1] = num_classes
    result = torch.zeros(shape)
    # scatter_ needs an int64 index on the same device as ``result`` (CPU),
    # so normalise both dtype and device of the indices.
    return result.scatter_(1, input.long().cpu(), 1)

class BinaryDiceLoss(nn.Module):
    """Dice loss for a single (binary) class.

    The sums run over the batch as well, so one Dice score is computed for
    the whole batch::

        loss = 1 - (2 * sum(x * y) + smooth) / (sum(x^p + y^p) + smooth)

    Args:
        smooth: additive smoothing term that avoids division by zero, default: 1.
        p: exponent in the denominator ``sum(x^p) + sum(y^p)``, default: 1.

    Forward shapes:
        predict: tensor of shape ``[N, *]``.
        target: tensor of the same shape as ``predict``.

    Returns:
        A scalar loss tensor.

    Raises:
        AssertionError: if ``predict`` and ``target`` batch sizes differ.
    """

    def __init__(self, smooth=1, p=1):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        # Flatten everything but the batch dimension; reshape copies only when needed.
        predict = predict.reshape(predict.shape[0], -1)
        target = target.reshape(target.shape[0], -1)

        num = torch.sum(predict * target) * 2 + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p)) + self.smooth
        return 1 - num / den

class DiceLoss(nn.Module):
    """Multi-class Dice loss computed from raw logits.

    ``predict`` is softmax-normalised over the class dimension and ``target``
    (class indices) is one-hot encoded internally. A ``BinaryDiceLoss`` is then
    evaluated for every class and the per-class losses are combined.

    Args:
        weight: optional tensor of shape ``[num_classes]`` with per-class weights.
        ignore_index: class index to skip, or ``None``.
        **kwargs: forwarded to ``BinaryDiceLoss`` (e.g. ``smooth``, ``p``).

    Forward shapes:
        predict: ``[N, C, *]`` logits.
        target: ``[N, *]`` integer class indices in ``[0, C)``.

    Returns:
        A scalar loss tensor: the sum of per-class losses divided by ``C``
        (unweighted) or by ``sum(weight)`` (weighted).

    NOTE(review): the ignored class still counts in the divisor (``C`` or
    ``sum(weight)``), so setting ``ignore_index`` scales the loss down.
    """

    def __init__(self, weight=None, ignore_index=None, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weights = weight
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        nclass = predict.shape[1]
        # [N, *] -> one-hot [N, *, C] -> [N, C, *]; movedim works for any rank.
        target = F.one_hot(target.long(), nclass).movedim(-1, 1)

        assert predict.shape == target.shape, 'predict & target shape do not match'
        if self.weights is not None:
            assert self.weights.shape[0] == nclass, \
                'Expect weight shape [{}], get[{}]'.format(nclass, self.weights.shape[0])

        dice = BinaryDiceLoss(**self.kwargs)
        predict = F.softmax(predict, dim=1)

        total_loss = 0
        for i in range(nclass):
            if i == self.ignore_index:
                continue
            dice_loss = dice(predict[:, i], target[:, i])
            if self.weights is not None:
                dice_loss = dice_loss * self.weights[i]
            total_loss += dice_loss

        if self.weights is None:
            return total_loss / nclass
        return total_loss / torch.sum(self.weights)

使用Dice loss:

# logits: [N=1, C=5, 2, 2], flattened below to [1, 5, 4]; target: [1, 4] class indices.
logits = torch.tensor([[0.1, 0.2], [0.3, 0.4]]).unsqueeze(0).unsqueeze(1).repeat(1, 5, 1, 1)
target = torch.tensor([[0, 0], [2, 4]]).unsqueeze(0)
dice_loss = DiceLoss(weight=torch.tensor([1.2, 1.3, 1.4, 1, 0.9]))
print(dice_loss(logits.view(1, 5, 4), target.view(1, 4)))
