class BinaryDiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation, averaged over the batch.

    For each sample ``i`` (all non-batch dimensions flattened):

        dice_i = (2 * sum(p_i * t_i) + smooth) / (sum(p_i) + sum(t_i) + smooth)

    and the returned loss is ``1 - sum_i(dice_i) / N``, i.e. one minus the batch mean.

    :param smooth: Additive smoothing term in numerator and denominator; keeps the
        ratio defined when prediction and target are both empty. Defaults to 1,
        the value that used to be hard-coded.
    """

    def __init__(self, smooth: float = 1):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, input, targets):
        """Compute the batch-averaged Dice loss.

        :param input: Predicted foreground scores, batch dimension first, any trailing
            shape. No activation is applied here -- presumably already probabilities
            (e.g. after sigmoid); confirm at the call site.
        :param targets: Binary ground truth with the same number of elements per
            sample as ``input``.
        :return: Scalar tensor ``1 - mean Dice``.
        """
        n = targets.size(0)
        # reshape (not view): also works for non-contiguous tensors, e.g. after permute().
        input_flat = input.reshape(n, -1)
        targets_flat = targets.reshape(n, -1)
        # Per-sample soft intersection.
        intersection = (input_flat * targets_flat).sum(1)
        dice = (2 * intersection + self.smooth) / (
            input_flat.sum(1) + targets_flat.sum(1) + self.smooth
        )
        # Average loss per image in the batch.
        return 1 - dice.sum() / n
class BalancedBCELoss(nn.Module):
    """Class-balanced cross-entropy with weights derived from a label map.

    Each class ``c`` in ``0 .. num_classes-1`` gets weight
    ``1 / (count(target == c) + eps)``, so rarer classes contribute more.
    The weights are computed ONCE, from the ``target`` passed to the
    constructor, and reused for every later ``forward`` call.

    Despite the name, this wraps ``nn.CrossEntropyLoss``, which expects raw
    logits, rather than binary cross-entropy.

    :param target: Class-index map ``[b, h, w]`` used to count pixels per class.
    :param num_classes: Number of classes to weight. Defaults to 2, the value
        that used to be hard-coded.
    """

    def __init__(self, target, num_classes: int = 2):
        super(BalancedBCELoss, self).__init__()
        self.eps = 1e-6
        # Pixel count per class; stacking the 0-d sums keeps them on target's device.
        counts = torch.stack([torch.sum(target == c) for c in range(num_classes)]).float()
        # eps avoids division by zero: a class absent from `target` gets weight 1/eps.
        weight = torch.reciprocal(counts + self.eps)
        self.criterion = nn.CrossEntropyLoss(weight)

    def forward(self, output, target):
        """Apply the weighted cross-entropy.

        :param output: Raw logits ``[b, class, h, w]`` (softmax NOT applied).
        :param target: Class-index map ``[b, h, w]``.
        :return: Scalar loss from ``nn.CrossEntropyLoss`` (default reduction).
        """
        return self.criterion(output, target)
具体使用(BalancedBCELoss):
# Usage: label is a [b, h, w] class-index map; maksR_modelA holds [b, class, h, w] raw logits (no softmax).
# The class weights are derived from `label` once, when the criterion is constructed.
seg_criterian = BalancedBCELoss(label).cuda()
bce_loss = seg_criterian(maksR_modelA, label)
参考文献:W-Net: A Deep Model for Fully Unsupervised Image Segmentation (Xia & Kulis, 2017)
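原理:W-Net uses the normalized cut (Ncut) [29] as a global criterion for the segmentation:

$$\mathrm{Ncut}_K(V)=\sum_{k=1}^{K}\frac{\mathrm{cut}(A_k,\,V-A_k)}{\mathrm{assoc}(A_k,\,V)}=\sum_{k=1}^{K}\frac{\sum_{u\in A_k,\,v\in V-A_k} w(u,v)}{\sum_{u\in A_k,\,t\in V} w(u,t)}$$

where $V$ is the set of all pixels, $A_k$ is the set of pixels assigned to segment $k$, and $w(u,v)$ is the affinity between pixels $u$ and $v$. Evaluating this requires a hard assignment of every pixel to a segment, i.e. an argmax over the predicted class probabilities.
However, since the argmax function is nondifferentiable, it is impossible to calculate the corresponding gradient during backpropagation. Instead, we define a soft version of the Ncut loss which is differentiable so that we can update gradients during backpropagation:

$$J_{\text{soft-Ncut}}(V,K)=K-\sum_{k=1}^{K}\frac{\sum_{u\in V} p(u=A_k)\sum_{v\in V} w(u,v)\,p(v=A_k)}{\sum_{u\in V} p(u=A_k)\sum_{t\in V} w(u,t)}$$

where $p(u=A_k)$ is the predicted (softmax) probability that pixel $u$ belongs to class $k$.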
PyTorch 代码
class NCutLoss2D(nn.Module):
    r"""Continuous (soft) N-Cut loss, after:
    'W-Net: A Deep Model for Fully Unsupervised Image Segmentation', by Xia, Kulis (2017).

    This simplified version keeps only the pixel-value affinity term. The
    Gaussian spatial filter of the original formulation is disabled, so
    ``radius`` and ``sigma_1`` are stored but not used by ``forward``.
    """

    def __init__(self, radius: int = 5, sigma_1: float = 4, sigma_2: float = 10):
        r"""
        :param radius: Radius of the spatial interaction term (currently unused).
        :param sigma_1: Standard deviation of the spatial Gaussian interaction (currently unused).
        :param sigma_2: Standard deviation of the pixel value Gaussian interaction.
        """
        super(NCutLoss2D, self).__init__()
        self.radius = radius
        self.sigma_1 = sigma_1  # Spatial standard deviation (unused, see class docstring)
        self.sigma_2 = sigma_2  # Pixel value standard deviation

    def forward(self, labels: Tensor, inputs: Tensor) -> Tensor:
        r"""Compute the continuous N-Cut loss from class probabilities and raw images.

        For efficiency, each pixel's affinity is measured against the
        probability-weighted class mean rather than against every other pixel:
        ``w_k(u) = exp(-||x(u) - mean_k||^2 / sigma_2^2)``. The loss is
        ``K - sum_k sum(p_k^2 * w_k) / sum(p_k * w_k)``, where the inner sums run
        over the whole batch (not per image). When ``labels`` are probabilities
        in [0, 1], each ratio lies in [0, 1], so the loss lies in [0, K].

        :param labels: Predicted class probabilities ``[b, K, h, w]`` (after softmax).
        :param inputs: Raw images ``[b, c, h, w]``.
        :return: Scalar continuous N-Cut loss.
        """
        num_classes = labels.shape[1]
        loss = 0
        for k in range(num_classes):
            class_probs = labels[:, k].unsqueeze(1)  # [b, 1, h, w]
            # Probability-weighted mean pixel value of class k, per image and channel;
            # 1e-5 guards against a class with (near-)zero total probability.
            class_mean = torch.mean(inputs * class_probs, dim=(2, 3), keepdim=True) / (
                torch.mean(class_probs, dim=(2, 3), keepdim=True) + 1e-5
            )
            # Squared Euclidean distance (over channels) of each pixel to the class mean.
            sq_dist = (inputs - class_mean).pow(2).sum(dim=1, keepdim=True)
            # Gaussian affinity exp(-d^2 / sigma^2). sq_dist is already d^2, so it is
            # not squared again (doing so gave exp(-d^4 / sigma^2)).
            weights = torch.exp(-sq_dist / self.sigma_2 ** 2)
            numerator = torch.sum(class_probs * class_probs * weights)
            denominator = torch.sum(class_probs * weights)
            # abs() matches the former L1-distance-to-zero term; the ratio is non-negative
            # for non-negative probabilities anyway.
            loss = loss + torch.abs(numerator / (denominator + 1e-6))
        return num_classes - loss
具体使用(NCutLoss2D):将经过 softmax 的预测图与原始输入图像一起计算 loss,原理类似于聚类。
# Usage: input is a [b, 1, h, w] image; predR_map holds [b, class, h, w] probabilities (after softmax).
# NCutLoss2D takes (labels, inputs): predictions first, raw image second.
softcut = NCutLoss2D()
loss = softcut(predR_map, input)