Focal Loss for Imbalanced Samples: A PyTorch Implementation and Analysis

Focal Loss is very effective for learning on imbalanced datasets and for weighting easy versus hard examples. This article walks through a simple implementation to deepen the understanding of Focal Loss. Without further ado, let's get started.

First, import the required PyTorch modules:

import torch
import torch.nn as nn

from IPython.display import display

class FocalLoss(nn.Module):

    def __init__(self, weight=None, reduction='mean', gamma=0, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps  # kept for interface compatibility; unused below
        # reduction is forwarded to the inner CrossEntropyLoss:
        # 'none' keeps per-sample losses (the exact focal loss),
        # 'mean' averages first (an approximation, see the demo below)
        self.ce = nn.CrossEntropyLoss(weight=weight, reduction=reduction)

    def forward(self, input, target):
        logp = self.ce(input, target)        # -log(p), scalar or per-sample
        p = torch.exp(-logp)                 # recover p from the CE loss
        loss = (1 - p) ** self.gamma * logp  # focal modulation: (1-p)^gamma
        return loss.mean()

The above is the core of the PyTorch implementation of Focal Loss, built mainly on torch.nn.CrossEntropyLoss. There are two key pieces:

- torch.nn.CrossEntropyLoss itself
- turning the cross-entropy loss into a focal loss

Let's first look at the documentation of torch.nn.CrossEntropyLoss.

Its prototype is:

torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

This criterion combines torch.nn.LogSoftmax and torch.nn.NLLLoss in a single class. (See the full docstring via torch.nn.CrossEntropyLoss??.)

The weight argument handles class imbalance by assigning a weight to each class. The two formulas below show the unweighted and weighted cases.

Formula 1:

$$ \text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right) = -x[class] + \log\left(\sum_j \exp(x[j])\right) $$

This formula first normalizes the logits with torch.nn.LogSoftmax and then takes the negative log-likelihood with torch.nn.NLLLoss.

Formula 2:

$$ \text{loss}(x, class) = weight[class] \left(-x[class] + \log\left(\sum_j \exp(x[j])\right)\right) $$
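As a quick numerical check of both formulas (a minimal sketch; the seed, logits, targets, and weight vector below are made up for illustration):

torch.manual_seed(0)
x = torch.randn(3, 5)                      # logits: 3 samples, 5 classes
y = torch.tensor([1, 0, 4])                # target class indices

# Formula 1: CrossEntropyLoss == NLLLoss applied to LogSoftmax output
ce = nn.CrossEntropyLoss()(x, y)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(x), y)
print(torch.allclose(ce, nll))             # True

# Formula 2: per-class weights rescale each sample's loss; with
# reduction='mean', PyTorch normalizes by the sum of the weights
# of the targets that actually appear in the batch
w = torch.tensor([1.0, 2.0, 1.0, 1.0, 0.5])
ce_w = nn.CrossEntropyLoss(weight=w)(x, y)
per_sample = nn.CrossEntropyLoss(reduction='none')(x, y)
manual = (w[y] * per_sample).sum() / w[y].sum()
print(torch.allclose(ce_w, manual))        # True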

Example code for CrossEntropyLoss:

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
output.backward()
Now compare plain cross-entropy against the two FocalLoss variants:

loss = nn.CrossEntropyLoss()
# Note the reduction parameter of the inner nn.CrossEntropyLoss:
# 'none' returns one loss value per sample in the batch;
# 'mean' (the default) returns their average
focalloss_1 = FocalLoss(gamma=1)
focalloss_2 = FocalLoss(gamma=1, reduction='none')

input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)

output = loss(input, target)
output_focalloss1 = focalloss_1(input, target)
output_focalloss2 = focalloss_2(input, target)

display(output, output_focalloss1, output_focalloss2)
tensor(2.2966, grad_fn=<NllLossBackward0>)
tensor(2.0656, grad_fn=<MeanBackward0>)
tensor(2.0632, grad_fn=<MeanBackward0>)

The two focal values differ slightly because focalloss_1 applies the modulating factor (1 - p)^gamma to the already-averaged cross-entropy, while focalloss_2 (reduction='none') applies it to each sample's loss before averaging, which matches the per-sample definition of the focal loss.

Replacing cross-entropy with the focal term modifies the two formulas above as follows.

Formula 1 (focal version):

$$ \text{loss}(x, class) = \left(1 - \frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)^\gamma \left(-\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)\right) = -(1 - p)^\gamma \log(p) $$

where $p = \frac{\exp(x[class])}{\sum_j \exp(x[j])}$.

Formula 2 (focal version with class weights):

$$ \text{loss}(x, class) = weight[class] \left(1 - \frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)^\gamma \left(-x[class] + \log\left(\sum_j \exp(x[j])\right)\right) = -w (1 - p)^\gamma \log(p) $$
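A quick sanity check (inputs made up for illustration): with gamma=0 the modulating factor $(1 - p)^\gamma$ is identically 1, so FocalLoss should reduce to plain cross-entropy.

torch.manual_seed(0)
x = torch.randn(3, 5, requires_grad=True)
y = torch.empty(3, dtype=torch.long).random_(5)

# gamma=0 is the class default, so FocalLoss() behaves the same way
print(torch.allclose(FocalLoss(gamma=0)(x, y), nn.CrossEntropyLoss()(x, y)))  # True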

Turning CrossEntropyLoss into FocalLoss

Since

$$ -\log(p) = \text{nn.CrossEntropyLoss}(input, target) $$

it follows that

$$ p = \exp\left(-\text{nn.CrossEntropyLoss}(input, target)\right) $$

The final focal loss is therefore

$$ \text{focalloss} = (1 - p)^\gamma \left(-\log(p)\right) $$
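This identity is easy to verify numerically (a sketch with made-up inputs): the per-sample losses returned by reduction='none' are exactly $-\log(p)$, so exponentiating their negative recovers the softmax probability of the target class.

torch.manual_seed(0)
x = torch.randn(3, 5)
y = torch.tensor([2, 0, 3])

logp = nn.CrossEntropyLoss(reduction='none')(x, y)   # -log(p), one per sample
p_from_ce = torch.exp(-logp)
p_direct = torch.softmax(x, dim=1)[torch.arange(3), y]
print(torch.allclose(p_from_ce, p_direct))           # True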

Finally, since training uses mini-batches, the last step takes the mean over the batch. As the demo above shows, for this mean to match the per-sample definition, the inner CrossEntropyLoss should use reduction='none' so that the modulating factor is applied before averaging.
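Putting the pieces together, a direct per-sample computation (a minimal sketch; inputs and gamma are made up) matches the FocalLoss class above when its inner loss uses reduction='none':

gamma = 2.0
torch.manual_seed(0)
x = torch.randn(4, 5)
y = torch.empty(4, dtype=torch.long).random_(5)

logp = nn.CrossEntropyLoss(reduction='none')(x, y)   # -log(p), shape (4,)
p = torch.exp(-logp)
focal = ((1 - p) ** gamma * logp).mean()             # mean over the batch

print(torch.allclose(focal, FocalLoss(gamma=gamma, reduction='none')(x, y)))  # True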
