Focal Loss 与 Progressive Focal Loss 的 PyTorch 实现

Focal Loss 与 Progressive Focal Loss 的 PyTorch 实现

  • 介绍
  • 代码

介绍

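Focal Loss 在交叉熵的基础上引入调制因子 $(1-p_t)^{\gamma}$ 和平衡系数 $\alpha_t$,即 $FL(p_t) = -\alpha_t (1-p_t)^{\gamma}\log(p_t)$,用来降低易分样本对损失的贡献。下面代码中的 Progressive Focal Loss 不使用固定的 $\gamma$ 和 $\alpha$,而是根据当前 batch 中每个类别正样本的平均预测概率 $\bar{p}_c$ 计算 $\gamma_c = -\log(\bar{p}_c)$(截断到 $[1.5, 2.5]$),并取 $\alpha = 0.5/\gamma$(截断到 $[0.2, 0.3]$)。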

代码

import torch
import torch.nn as nn
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, progressive=False):
        super(FocalLoss, self).__init__()
        # progressive: 使用 progressive focal loss,懂的都懂,不懂的我也没办法,好用就行
        self.alpha = alpha
        self.gamma = gamma
        self.progressive = progressive



    def forward(self, inputs, targets):
        # inputs: n, c
        # 公式: - \alpha * (1-p)^\gamma * log_p
        # inputs为没经过softmax的原始网络输出,target一般为一维的label: (batch, )


        # sigmoid 可换成 softmax
        pred = torch.clamp(torch.sigmoid(inputs), min=1e-4, max=1 - 1e-4) # (batch, num_class)
        num_classes = inputs.size(1)
        device = targets.device
        class_range = torch.arange(0, num_classes, dtype=targets.dtype, device=device).unsqueeze(0)
        t = targets.unsqueeze(1)

        if self.progressive:
            # 确定每个类别
            cls_same = (t == class_range).float()  # (n, num_cls)
            # 预测的数据
            pred_nogard = pred.detach().data
            # 每个类别正类的数量
            cls_pos_num = torch.sum(cls_same, dim=0)  # (num_cls)
            # 每个类别预测概率的求和
            pred_pos_sum = torch.sum(cls_same * pred_nogard, dim=0)  # (num_cls)
            # 每个类别的gamma
            cls_gamma = - torch.log(pred_pos_sum / (cls_pos_num + 1e-7))  # (num_cls)
            cls_gamma = torch.clamp(cls_gamma, min=1.5, max=2.5)

            # expand gamma
            cur_gamma = cls_same * cls_gamma  # (n, num_cls)
            # 得到alpha
            cur_alpha = 0.5 / (cur_gamma + 1e-7)
            cur_alpha = torch.clamp(cur_alpha, min=0.2, max=0.3)  # (n, num_cls)

            # 对每个大类进行focal loss
            term1 = (1 - pred) ** cur_gamma * torch.log(pred)
            term2 = pred ** cur_gamma * torch.log(1 - pred)
            loss_org = -term1 * cur_alpha * cls_same - term2 * (1 - cur_alpha) * ((t != class_range) * (t >= 0)).float()
            
        else:
            term1 = (1 - pred) ** self.gamma * torch.log(pred)
            term2 = pred ** self.gamma * torch.log(1 - pred)
            loss_org = -(t == class_range).float() * term1 * self.alpha - (
                        (t != class_range) * (t >= 0)).float() * term2 * (
                           1 - self.alpha)

        return loss_org.mean()

你可能感兴趣的:(计算机视觉)