Balanced binary cross-entropy 损失函数在分割任务中很有用，因为分割任务经常会遇到正负样本不均衡的问题，尤其是在边缘分割任务中，正负样本的不均衡比例可能非常高。
因此，我在自己的分割任务中实现了该损失函数，亲测有效！
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
from .utils import weight_reduce_loss
def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
    """Softmax cross-entropy followed by optional weighting and reduction.

    Args:
        pred (Tensor): unnormalized class scores, as consumed by
            ``F.cross_entropy`` (presumably (N, C) — confirm at call sites).
        label (Tensor): target class indices for ``F.cross_entropy``.
        weight (Tensor, optional): per-element weights; cast to float before
            being handed to ``weight_reduce_loss``.
        reduction (str): reduction mode forwarded to ``weight_reduce_loss``.
        avg_factor (float, optional): averaging factor forwarded to
            ``weight_reduce_loss``.

    Returns:
        Tensor: the reduced loss as returned by ``weight_reduce_loss``.
    """
    per_element = F.cross_entropy(pred, label, reduction='none')
    if weight is not None:
        weight = weight.float()
    # Weighting and reduction are both delegated to the shared helper.
    return weight_reduce_loss(
        per_element, weight=weight, reduction=reduction, avg_factor=avg_factor)
def _expand_binary_labels(labels, label_weights, label_channels):
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
bin_labels[inds, labels[inds] - 1] = 1
if label_weights is None:
bin_label_weights = None
else:
bin_label_weights = label_weights.view(-1, 1).expand(
label_weights.size(0), label_channels)
return bin_labels, bin_label_weights
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None):
    """Sigmoid binary cross-entropy on logits, with optional weighting.

    Args:
        pred (Tensor): raw logits, as consumed by
            ``F.binary_cross_entropy_with_logits``.
        label (Tensor): targets. If ``label`` has a different number of
            dimensions than ``pred``, it is treated as integer class ids.
            It is then expanded to one-hot over ``pred.size(-1)`` columns
            via ``_expand_binary_labels``, and ``weight`` is expanded with it.
        weight (Tensor, optional): element-wise weights, cast to float.
        reduction (str): reduction mode forwarded to ``weight_reduce_loss``.
        avg_factor (float, optional): forwarded to ``weight_reduce_loss``.

    Returns:
        Tensor: the reduced loss as returned by ``weight_reduce_loss``.
    """
    if label.dim() != pred.dim():
        label, weight = _expand_binary_labels(label, weight, pred.size(-1))

    weight = None if weight is None else weight.float()
    # The weight is applied inside the element-wise BCE call itself ...
    per_element = F.binary_cross_entropy_with_logits(
        pred, label.float(), weight, reduction='none')
    # ... so only the reduction remains to be done here.
    return weight_reduce_loss(
        per_element, reduction=reduction, avg_factor=avg_factor)
def balanced_mask_cross_entropy(pred, label, mask=None, negative_ratio=3.0, eps=1e-10):
    """Binary cross-entropy with online hard-negative mining.

    Every positive element contributes to the loss. Negatives are capped at
    ``positive_count * negative_ratio``, and only the hardest ones (largest
    loss) are kept. The summed loss is divided by the number of elements kept.

    The loss is computed element-wise over the full shape of ``pred``. The
    previous version sliced it to channel 0 (shape (N, H, W)) while the
    positive/negative masks kept shape (N, 1, H, W). The multiplication then
    broadcast to (N, N, H, W), which gave wrong results for batch size > 1.

    Args:
        pred (Tensor): probabilities in [0, 1] (``F.binary_cross_entropy``
            takes probabilities, not logits). Typically (N, 1, H, W).
        label (Tensor): targets with the same shape as ``pred``. Values are
            truncated with ``.byte()`` to decide positive/negative, so
            fractional values count as neither.
        mask (Tensor, optional): valid-region mask, broadcastable to
            ``label``. A mask with one fewer dimension than ``label`` (e.g.
            (N, H, W) for an (N, 1, H, W) label) gets a channel axis
            inserted at dim 1. None means every element is valid.
        negative_ratio (float): maximum number of negatives kept per positive.
        eps (float): added to the denominator to avoid division by zero.

    Returns:
        Tensor: scalar loss. It is 0 when there are no positives, because the
        negative budget is then 0 as well.
    """
    # F.binary_cross_entropy requires a floating target matching pred.
    label = label.type_as(pred)
    if mask is None:
        mask = torch.ones_like(label)
    else:
        mask = mask.type_as(pred)
        if mask.dim() == label.dim() - 1:
            mask = mask.unsqueeze(1)

    positive = (label * mask).byte().float()
    negative = ((1 - label) * mask).byte().float()
    positive_count = int(positive.sum())
    negative_count = min(int(negative.sum()), int(positive_count * negative_ratio))

    loss = F.binary_cross_entropy(pred, label, reduction='none')
    positive_loss = loss * positive
    negative_loss = loss * negative
    # Keep only the hardest negatives. topk with k=0 yields an empty tensor.
    negative_loss, _ = torch.topk(negative_loss.reshape(-1), negative_count)

    balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
        positive_count + negative_count + eps)
    return balance_loss
@LOSSES.register_module()
class BalancedCrossEntropyLoss(nn.Module):
    """Module wrapper around ``balanced_mask_cross_entropy``.

    Args:
        negative_ratio (float): maximum number of hard negatives kept per
            positive element.
        eps (float): denominator epsilon passed to the criterion.
        loss_weight (float): scalar multiplier applied to the loss.
    """

    def __init__(self,
                 negative_ratio=3.0,
                 eps=1e-10,
                 loss_weight=1.0):
        super(BalancedCrossEntropyLoss, self).__init__()
        self.negative_ratio = negative_ratio
        self.eps = eps
        self.loss_weight = loss_weight
        self.cls_criterion = balanced_mask_cross_entropy

    def forward(self,
                pred,
                label,
                mask=None,
                **kwargs):
        """Compute the weighted balanced BCE loss.

        Args:
            pred (Tensor): probabilities in [0, 1].
            label (Tensor): targets with the same shape as ``pred``.
            mask (Tensor, optional): valid-region mask. It is now actually
                forwarded; the previous version always passed None.
            **kwargs: accepted for call-site compatibility and ignored. The
                criterion has no parameters to receive them, and forwarding
                them raised TypeError.

        Returns:
            Tensor: ``loss_weight`` times the criterion output.
        """
        loss_cls = self.loss_weight * self.cls_criterion(
            pred, label, mask=mask, negative_ratio=self.negative_ratio, eps=self.eps
        )
        return loss_cls