import torch
import torch.nn as nn
import torch.nn.functional as F
class Focal_loss(nn.Module):
    def __init__(self, weight=None, gamma=0):
        super(Focal_loss, self).__init__()
        self.weight = weight  # optional per-class weights, shape [C]
        self.gamma = gamma    # focusing parameter; gamma=0 recovers cross entropy
        self.eps = 1e-8

    def forward(self, predict, target):
        # predict: [N, C, L] logits; target: [N, L] class indices
        if self.weight is not None:
            # broadcast class weights to [N, L, C]
            weights = self.weight.unsqueeze(0).unsqueeze(1).repeat(predict.shape[0], predict.shape[2], 1)
        target_onehot = F.one_hot(target.long(), predict.shape[1])  # [N, L, C]
        if self.weight is not None:
            weights = torch.sum(target_onehot * weights, -1)  # [N, L], weight of the true class
        input_soft = F.softmax(predict, dim=1)
        # probability of the true class at each position; the clamp here is essential,
        # otherwise the loss can become NaN
        probs = torch.sum(input_soft.transpose(2, 1) * target_onehot, -1).clamp(min=0.001, max=0.999)
        focal_weight = (1 + self.eps - probs) ** self.gamma
        if self.weight is not None:
            return torch.sum(-torch.log(probs) * weights * focal_weight) / torch.sum(weights)
        else:
            return torch.mean(-torch.log(probs) * focal_weight)
Using Focal loss:
input = torch.rand(1, 2, 2)
target = torch.tensor([[0, 1]]).repeat(input.shape[0], 1)
FL= Focal_loss(gamma=0.5)
print('focal loss: ', FL(input, target))
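As a quick sanity check (a minimal sketch, not part of the original code): with gamma=0 and no class weights, the focal weight is 1 and the loss reduces to plain cross entropy, so it should match F.cross_entropy on the same inputs:
FL0 = Focal_loss(gamma=0)
print(torch.allclose(FL0(input, target), F.cross_entropy(input, target)))  # expected: True (the clamp does not bite for these mild logits)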
Be sure to clamp the range of probs, otherwise the loss may become NaN during training:
probs = torch.sum(input_soft.transpose(2, 1) * target_onehot, -1).clamp(min=0.001, max=0.999)  # the clamp here is essential, otherwise the loss can become NaN
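To see why: with a large enough logit gap, softmax underflows to exactly 0, and log(0) is -inf, which then propagates as NaN through the loss and gradients. A minimal illustration:
p = F.softmax(torch.tensor([[100.0, -100.0]]), dim=1)  # tensor([[1., 0.]]) -- the small probability underflows to 0
print(torch.log(p))  # tensor([[0., -inf]])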
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.

    Args:
        input: A tensor of shape [N, 1, *]
        num_classes: An int of number of classes
    Returns:
        A tensor of shape [N, num_classes, *]
    """
    shape = np.array(input.shape)
    shape[1] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape)
    # scatter 1s along the channel dimension at the index positions
    # (scatter_ requires an int64 index tensor)
    result = result.scatter_(1, input.cpu().long(), 1)
    return result
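A quick usage sketch (illustrative, not from the original post): make_one_hot expects integer class indices with a singleton channel dimension.
idx = torch.tensor([[[0, 2, 1]]])        # shape [1, 1, 3], dtype int64
print(make_one_hot(idx, num_classes=3))  # one-hot of shape [1, 3, 3]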
class BinaryDiceLoss(nn.Module):
    """Dice loss for a binary class.

    Args:
        smooth: A float added to numerator and denominator to smooth the loss
            and avoid division by zero, default: 1
        p: Denominator exponent: \sum{x^p} + \sum{y^p}, default: 1
        predict: A tensor of shape [N, *]
        target: A tensor of the same shape as predict
    Returns:
        A scalar loss tensor
    """
    def __init__(self, smooth=1, p=1):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        # flatten everything except the batch dimension
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1)
        num = torch.sum(torch.mul(predict, target)) * 2 + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p)) + self.smooth
        dice = num / den
        loss = 1 - dice
        return loss
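A minimal sanity check (not from the original post): when predict equals the binary target exactly, dice = (2*|X∩Y| + smooth) / (|X| + |Y| + smooth) = 1, so the loss is 0.
t = torch.tensor([[1.0, 0.0, 1.0, 1.0]])
print(BinaryDiceLoss()(t, t))  # tensor(0.)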
class DiceLoss(nn.Module):
    """Multi-class Dice loss; the target is one-hot encoded internally.

    Args:
        weight: A tensor of shape [num_classes,]
        ignore_index: class index to ignore
        predict: A tensor of shape [N, C, *]
        target: A tensor of class indices of shape [N, *]
        other args pass to BinaryDiceLoss
    Return:
        same as BinaryDiceLoss
    """
    def __init__(self, weight=None, ignore_index=None, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weights = weight
        self.ignore_index = ignore_index

    def forward(self, predict, target):
        nclass = predict.shape[1]
        target = torch.nn.functional.one_hot(target.long(), nclass)  # e.g. [1, 4] -> [1, 4, 5]
        # for a 4D target [N, H, W] one would instead use:
        # target = torch.transpose(torch.transpose(target, 1, 3), 2, 3)
        target = torch.transpose(target, 1, 2)  # [N, L, C] -> [N, C, L]
        assert predict.shape == target.shape, 'predict & target shape do not match'
        dice = BinaryDiceLoss(**self.kwargs)
        total_loss = 0
        predict = F.softmax(predict, dim=1)
        # accumulate a binary dice loss per class
        for i in range(target.shape[1]):
            if i != self.ignore_index:
                dice_loss = dice(predict[:, i], target[:, i])
                if self.weights is not None:
                    assert self.weights.shape[0] == target.shape[1], \
                        'Expect weight shape [{}], got [{}]'.format(target.shape[1], self.weights.shape[0])
                    dice_loss *= self.weights[i]
                total_loss += dice_loss
        return total_loss / target.shape[1] if self.weights is None else total_loss / torch.sum(self.weights)
Using Dice loss:
input = torch.tensor([[0.1, 0.2], [0.3, 0.4]]).unsqueeze(0).unsqueeze(1).repeat(1, 5, 1, 1)
target = torch.tensor([[0, 0], [2, 4]]).unsqueeze(0)
dice_loss = DiceLoss(weight=torch.tensor([1.2, 1.3, 1.4, 1, 0.9]))
print(dice_loss(input.view(1, 5, 4), target.view(1, 4)))
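As a hedged sanity check (illustrative, not from the original post), logits that put essentially all softmax mass on the true class should drive the loss toward 0:
tgt = torch.tensor([[0, 0, 2, 4]])
logits = F.one_hot(tgt, 5).transpose(1, 2).float() * 100  # [1, 5, 4], nearly one-hot after softmax
print(DiceLoss()(logits, tgt))  # ≈ tensor(0.)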