FocalLoss 带mask的多分类 代码实现

关于介绍FocalLoss的博客很多,这里做一个简单的总结,并且实现了一个多分类的数据不均衡的FocalLoss。

FocalLoss用来解决的问题

FocalLoss这个损失函数是在目标检测领域(由Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár提出) 针对one-stage的目标检测框架(例如SSD, YOLO)中正(前景)负(背景)样本极度不平均,负样本loss值主导整个梯度下降, 正样本占比小, 导致模型只专注学习负样本上。

FocalLoss是基于CrossEntropyLoss 改进的,专门处理分类数据不均衡的问题。

下面我们将会从CrossEntropy 开始,引入FocalLoss,然后做个对比

CrossEntropyLoss

我们先看

二分类的CrossEntropy:

L o s s = − y l o g ( p ) + ( 1 − y ) l o g ( 1 − p ) Loss = -y log(p) + (1-y)log(1-p) Loss=ylog(p)+(1y)log(1p)
也可以分开写 : 当y=1时,loss = -log§, 其他情况,loss = -log(1-p)

多分类 CrossEntropy

C E ( P t ) = − l o g ( P t ) CE(P_t) = -log(P_t) CE(Pt)=log(Pt)
其中 Pt是经过 softmax 转成 概率得到的
P t = e x , t ∑ j e x , j P_t = \frac {e^{x,t}}{\sum_j e^{x,j}} Pt=jex,jex,t

其实CE的公式简单明了, 但是当遇到样本极度不平均的情况下加和所有的loss值时, 正样本的loss值占比会非常小, 什么意思呢? 后面会举例讲解。

FocalLoss

F L ( p t ) = − α ( 1 − p t ) γ l o g ( p t ) FL(p_t) = - \alpha (1-p_t)^\gamma log(p_t) FL(pt)=α(1pt)γlog(pt)

FocalLoss 比原来的CrossEntropyLoss 多了一组权重系数 − α ( 1 − p t ) γ -\alpha (1-p_t)^\gamma α(1pt)γ, 同时多了两个超参数 α \alpha α γ \gamma γ.

对比

CrossEntropy随着p的变化

C E ( p t ) = − l o g ( p t ) CE(p_t) = -log(p_t) CE(pt)=log(pt) 当pt越大,loss越小, pt 越小,loss 越大。

FL 在不考虑 α \alpha α γ \gamma γ时, 即认为值为1

此时 FL为: F L ( p t ) = − ( 1 − p t ) l o g ( p t ) FL(p_t) = - (1-p_t)log(p_t) FL(pt)=(1pt)log(pt)

在不考虑a 和 gamma时 , 可以看出 当pt越大 -log(pt)越小,(1-pt)越小,所以权重越小。当pt越小,-log(pt)越大,(1-pt)越大,权重越大。 所以这个相当于对每个元素的loss 进行了权重调整。对于预测正确的 或者 概率高的 减少loss, 对于预测较差的结果,增大在loss中的分量。

FL只考虑 γ \gamma γ时, 即认为 α \alpha α值为1

此时FL为: F L ( p t ) = − ( 1 − p t ) γ l o g ( p t ) FL(p_t) = -(1-p_t)^\gamma log(p_t) FL(pt)=(1pt)γlog(pt)

假设gamma==2, 负样本prob= 0.95,
带入FL公式:-(1-0.95)^2log(0.95) = 0.00005569
带入CE公式:-log(0.95) = 0.02227639
结论:gamma能够有效降低负样本的loss值(简单样本的loss值),简单样本的概率越大效果越强。

FL考虑 α \alpha α γ \gamma γ

α \alpha α主要用来调和 正负样本权重比的

  • 假设负样本10000笔资料,probability(pt) = 0.95(简单样本)
  • 正样本10笔资料,probability(pt) = 0.05(困难样本)
    1. 直接带入CE公式: CE(pt) = -log(pt)
      负样本: log(pt) * 样本数 = 0.02227 * 10000 = 2227
      正样本: log(pt) * 样本数 = 1.30102 * 10 = 13.0102
      total loss = 2227 + 13.0102 = 2240
      loss中正样本占比:13.0102 / 2240 = 0.0058
    2. 带入FL 公式: FL(pt) = -a(1-pt)^r log(pt), 假设alpha=0.25, gamma=2
      负样本: (1-a)(1-pt)^r log(pt) * 样本数 = 0.75 * (1-0.95)^2 * 0.02227 * 10000 = 4.1756
      正样本: 0.25 *(1-0.05)^2 * 1.30102 * 1- = 2.935
      total loss = 4.175 + 2.935 = 7.110
      loss 中 正样本的占比: 2.935 /7.110 = 0.4127

结论: 经过比较, 我们算出CE正样本的值占总loss比例是0.0058, 而FocalLoss計算的正样本占比是0.4127,相差了71倍, 可以看出FL能有效提升正样本的loss占比

上面的例子中alpha取值为0.25, gamma=2,这是作者建议的最佳值(PS: gamma = 2, alpha = 0.25是经过作者不断尝试出的一般最佳值)
alpha 的0.25代表的是正样本, 所以负样本就会是1-0.25 = 0.75
就理论上来看,alpha值设定为0.75(因为正样本通常数量小)是比较合理, 但是毕竟还有gamma值在, 已经将负样本损失值降低许多,可理解为alpha和gamma相互牵制,alpha也不让正样本占比过大,因此最终设定为0.25, 如果有更好的理解欢迎留言一起讨论

FL中 α \alpha α γ \gamma γ的作用

  • gamma负责降低简单样本的损失值, 以解决加总后负样本loss值很大
  • alpha调和正负样本的不平均,如果设置0.25, 那么就表示负样本为0.75, 对应公式 1-alpha

代码

"""
# 2D 多分类的 FocalLoss (带mask版本)
如果是二分类问题,alpha 可以设置为一个值
如果是多分类问题,这里只能取list 外面要设置好list 并且长度要与分类bin大小一致,并且alpha的和要为1  
比如dist 的alpha=[0.02777]*36 +[0.00028] 这里是37个分类,设置前36个分类系数一样,最后一个分类权重系数小于前面的。
(其实多分类 alpha 不设置 也没问题的,因为gamma的已经在起作用了)
注意: 这里默认 input 已经经过网络最后一层的softmax了所以本身就是一个概率值。如果网络最后一层没有softmax,可以使用下面那个
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, device='cuda'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        if isinstance(alpha, (float, int, long)): 
        	"""
        	如果是二分类问题,alpha 可以设置为一个值
        	"""
            self.alpha = torch.Tensor([alpha, 1-alpha])
        if isinstance(alpha, list): 
        	"""
        	如果是多分类问题,这里只能取list 外面要设置好list 并且长度要与分类bin大小一致,并且alpha的和要为1  
        	比如dist 的alpha=[0.02777]*36 +[0.00028] 这里是37个分类,设置前36个分类系数一样,最后一个分类权重系数小于前面的。
        	"""
            self.alpha = torch.Tensor(alpha).to(device)

    def forward(self, input, target, mask):
        # input:  Batch_size * channel * L * L #      Batchsize * H * W * 37 个bin的概率
        # target: Batch_size * L * L (里面的值都是0-36)值
        # mask: Batch_size * L * L (里面是0 和1 值,0的地方不需要算loss)
        if torch.sum(mask.float()).item() == 0:
            return torch.tensor(0., requires_grad=True)
        input = input.permute(0,2,3,1).contiguous() # B * L * L * P
        # print(input.size(), target.size())
        pt = input.gather(-1, target.unsqueeze(-1).long()).squeeze() # 获取target对应的概率值 B * L * L 
        
        # 求出log值,然后乘以mask 得到 logpt
        logpt = torch.log(pt + 1e-6) * mask.float() # B * L * L # 乘上mask
        # 各个bin的概率值
        alpha = self.alpha.expand(input.shape) # B * L * L * P
        # 得到对应target 的各个bin的概率
        at = alpha.gather(-1, target.unsqueeze(-1).long()).squeeze()# alpha 里面存放 37个bin的权重比例 [0.02777]*36 +[0.00028]
        #FL(pt) = -a (1-pt)^r log(pt)   
        logpt = logpt * Variable(at) # B * L * L 
        loss = -1* (1-pt)**self.gamma * logpt 

        loss = loss.sum()/mask.float().sum()
        return loss
"""
# 多分类的 FocalLoss
如果是二分类问题,alpha 可以设置为一个值
如果是多分类问题,这里只能取list 外面要设置好list 并且长度要与分类bin大小一致,并且alpha的和要为1  
比如dist 的alpha=[0.02777]*36 +[0.00028] 这里是37个分类,设置前36个分类系数一样,最后一个分类权重系数小于前面的。

注意: 这里默认 input是没有经过softmax的,并且把shape 由H*W 2D转成1D的,然后再计算
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha,(float,int,long)): self.alpha = torch.Tensor([alpha,1-alpha])
        if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim()>2:
            input = input.view(input.size(0),input.size(1),-1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1,2)    # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1,input.size(2))   # N,H*W,C => N*H*W,C
        target = target.view(-1,1)

        logpt = F.log_softmax(input) # 这里转成log(pt)
        logpt = logpt.gather(1,target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type()!=input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0,target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1-pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()

参考:
FocalLoss
FocalLoss代码

你可能感兴趣的:(pytorch,笔记,深度学习,神经网络,pytorch)