【PyTorch】A Balanced_CE_loss Implementation

A custom implementation of balanced binary cross-entropy in PyTorch

The balanced binary cross-entropy loss is useful in segmentation tasks, which routinely suffer from an imbalance between positive and negative samples; in edge segmentation in particular, the imbalance can reach a very high ratio.

For that reason, I implemented this loss myself for my segmentation work, and it has proven effective in my own experiments.
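
The balancing scheme implemented below is online hard negative mining: every positive pixel is kept, only the hardest k negative pixels are kept, where k = min(#negatives, negative_ratio × #positives), and the summed loss is normalized by the number of pixels that were kept:

    balance_loss = (sum(positive_loss) + sum(topk(negative_loss, k))) / (positive_count + negative_count + eps)

The eps term guards against division by zero when an image contains no positive pixels at all.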

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss


def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
    # element-wise losses
    loss = F.cross_entropy(pred, label, reduction='none')
    # apply weights and do the reduction
    if weight is not None:
        weight = weight.float()
    loss = weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
    return loss
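
# Note: `weight_reduce_loss` comes from an mmdetection-style `.utils` module
# and is not defined in this file. For readers without that codebase, the
# common implementation (an assumption, shown only for reference) multiplies
# by the optional element-wise weight and then reduces:
#
#     def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
#         if weight is not None:
#             loss = loss * weight
#         if avg_factor is None:
#             if reduction == 'mean':
#                 return loss.mean()
#             if reduction == 'sum':
#                 return loss.sum()
#             return loss
#         if reduction == 'mean':
#             return loss.sum() / avg_factor
#         if reduction == 'none':
#             return loss
#         raise ValueError('avg_factor can not be used with reduction="sum"')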


def _expand_binary_labels(labels, label_weights, label_channels):
    # Expand class-index labels into one-hot binary targets; label 0 is
    # treated as background, so class k maps to channel k - 1.
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
    inds = torch.nonzero(labels >= 1).squeeze()
    if inds.numel() > 0:
        bin_labels[inds, labels[inds] - 1] = 1

    if label_weights is None:
        bin_label_weights = None
    else:
        # broadcast the per-sample weight across all label channels
        bin_label_weights = label_weights.view(-1, 1).expand(
            label_weights.size(0), label_channels)

    return bin_labels, bin_label_weights
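
# Worked example for _expand_binary_labels, assuming 3 label channels:
#   labels = tensor([0, 2, 1])  ->  bin_labels = [[0, 0, 0],
#                                                 [0, 1, 0],
#                                                 [1, 0, 0]]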


def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None):
    # `pred` holds raw logits here; BCE-with-logits applies the sigmoid itself
    if pred.dim() != label.dim():
        label, weight = _expand_binary_labels(label, weight, pred.size(-1))

    # weighted element-wise losses
    if weight is not None:
        weight = weight.float()
    loss = F.binary_cross_entropy_with_logits(pred, label.float(), weight, reduction='none')
    # do the reduction for the weighted loss
    loss = weight_reduce_loss(loss, reduction=reduction, avg_factor=avg_factor)

    return loss


def balanced_mask_cross_entropy(pred, label, mask=None, negative_ratio=3.0, eps=1e-10):
    """Balanced BCE with online hard negative mining.

    Expects `pred` (probabilities, i.e. after a sigmoid) and `label` to share
    the same shape, e.g. (N, 1, H, W). `mask` optionally marks the valid
    region (1 = valid, 0 = ignore).
    """
    if mask is None:
        mask = torch.ones_like(label)
    positive = (label * mask).float()
    negative = ((1 - label) * mask).float()
    positive_count = int(positive.sum())
    # cap the number of negatives at negative_ratio * positive_count
    negative_count = min(int(negative.sum()), int(positive_count * negative_ratio))
    # element-wise BCE; shapes stay aligned with the positive/negative masks
    loss = F.binary_cross_entropy(pred, label.float(), reduction='none')
    positive_loss = loss * positive
    negative_loss = loss * negative
    # keep only the hardest (highest-loss) negatives
    negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)

    balance_loss = (positive_loss.sum() + negative_loss.sum()) / (positive_count + negative_count + eps)
    return balance_loss



@LOSSES.register_module()
class BalancedCrossEntropyLoss(nn.Module):
    """Module wrapper around `balanced_mask_cross_entropy` for the LOSSES registry."""

    def __init__(self,
                 negative_ratio=3.0,
                 eps=1e-10,
                 loss_weight=1.0):
        super(BalancedCrossEntropyLoss, self).__init__()
        self.negative_ratio = negative_ratio
        self.eps = eps
        self.loss_weight = loss_weight
        self.cls_criterion = balanced_mask_cross_entropy

    def forward(self,
                pred,
                label,
                mask=None,
                **kwargs):
        # pass the (optional) valid-region mask through to the criterion
        loss_cls = self.loss_weight * self.cls_criterion(
            pred, label, mask=mask, negative_ratio=self.negative_ratio,
            eps=self.eps, **kwargs)

        return loss_cls
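
To sanity-check the loss outside the mmdet-style registry (the @LOSSES decorator requires that framework), the core function can be exercised directly. The shapes and sparsity below are illustrative assumptions, mimicking an edge-segmentation target:

import torch

# a 2-image batch with roughly 5% positive pixels
pred = torch.sigmoid(torch.randn(2, 1, 64, 64))    # probability map in (0, 1)
label = (torch.rand(2, 1, 64, 64) > 0.95).float()  # sparse binary target

loss = balanced_mask_cross_entropy(pred, label, negative_ratio=3.0)
print(loss.item())  # scalar; only the hardest 3x-positive-count negatives contribute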
