Pytorch实现多分类问题样本不均衡的权重损失函数 FocusLoss

  1. 原理:

    <点击跳转至原理解释>

  2. 实现代码,可直接使用 logitslabel 作为输入参数

    import torch
    import torch.nn.functional as F
    import torch.nn as nn
    
    class Focal_Loss(nn.Module):
        def __init__(self, weight, gamma=2):
            super(Focal_Loss,self).__init__()
            self.gamma = gamma
            self.weight = weight        # 是tensor数据格式的列表
    
        def forward(self, preds, labels):
            """
            preds:logist输出值
            labels:标签
            """
            preds = F.softmax(preds,dim=1)
            print(preds)
            eps = 1e-7
    
            target = self.one_hot(preds.size(1), labels)
            print(target)
    
            ce = -1 * torch.log(preds+eps) * target
            print(ce)
            floss = torch.pow((1-preds), self.gamma) * ce
            print(floss)
            floss = torch.mul(floss, self.weight)
            print(floss)
            floss = torch.sum(floss, dim=1)
            print(floss)
            return torch.mean(floss)
    
        def one_hot(self, num, labels):
            one = torch.zeros((labels.size(0),num))
            one[range(labels.size(0)),labels] = 1
            return one
    
  3. 参数说明

    1. 初始化类时,需要传入 a 列表,类型为tensor,表示每个类别的样本占比的反比,比如5分类中,有某一类占比非常多,那么就设置为小于0.2,即相应的权重缩小,占比很小的类,相应的权重就要大于0.2

      lf = Focal_Loss(torch.tensor([0.2,0.2,0.2,0.2,0.2]))
      
    2. 使用时,logits 是神经网络的输出,不用计算softmax,label是torchvision类自动生成的标签

      loss = lf(logits,label)
      
  4. 例子,这里 logits 为(16*5)的tensor,表示批大小为16,5分类;label为每个样本的真实标签类别,对应 logits 的下标,是一个16维的tensor向量

    logits = torch.tensor([[-2.7672,  3.6104, -7.4242, -3.2486, -3.1323],
                           [-2.4270,  3.1833, -5.9394, -2.4592, -3.2292],
                           [-2.5986,  3.3626, -6.7340, -2.8639, -3.1553],
                           [-2.6206,  3.4201, -6.8754, -2.9308, -3.1507],
                           [-2.8307,  3.7070, -7.6975, -3.3924, -3.1318],
                           [-2.5776,  3.3316, -6.6595, -2.8187, -3.1542],
                           [-2.8930,  3.7982, -7.9322, -3.5327, -3.1210],
                           [-2.5489,  3.3580, -6.5229, -2.7590, -3.1912],
                           [-1.5628,  1.8362, -1.8254, -0.3083, -3.5928],
                           [ 0.2434, -4.9000,  1.1150,  2.7505, -1.0390],
                           [-2.6877,  3.5686, -7.1178, -3.0847, -3.1617],
                           [-2.6847,  3.5191, -6.8264, -3.0083, -3.2041],
                           [-2.6137,  3.4025, -6.8965, -2.9250, -3.1396],
                           [-2.7505,  3.5840, -7.3340, -3.2035, -3.1435],
                           [-2.7030,  3.5163, -7.1549, -3.1002, -3.1424],
                           [-2.6661,  3.4580, -7.0481, -3.0258, -3.1365]])
    
    label = torch.tensor([1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 1, 1, 3, 1, 1, 1])
    
    lf = Focal_Loss(torch.tensor([0.2,0.2,0.2,0.2,0.2]))
    
    loss = lf(logits,label)
    print('loss:', loss)
    

    输出结果

    loss: tensor(0.1902)
    

你可能感兴趣的:(pytorch,python,人工智能,图像处理)