最近有人问我focalloss是什么原理,看不懂,大多数网文看了还更朦胧,于是我抽空整理了一篇文章帮助大家理解。
Focal loss解决了什么问题?降低易分样本权重、增加难分样本的权重训练。
对于难易样本,CE Loss 是一致的对待,当累加了大量易分样本的 loss 后,难分样本数量少,其loss贡献几乎被完全淹没,造成易分样本梯度占据了主导,网络偏向于学习这些很容易分类正确的, 对于少数真正感兴趣的难样本却难以学习,而训练到一定时候要想提升模型效果还得想办法依靠这些难分样本。Focal Loss 主要关注解决在训练过程中大量的易分样本淹没检测器的问题,类别不平衡的问题解决方式同理。
(1)难易样本不均衡
(2)不同类别不均衡
Focal loss是怎么改进CrossEntropy Loss的呢?
二分类交叉熵损失(CrossEntropy Loss):
Focal Loss:
其中,y是真实标签,y'是预测出来的概率值,对于二分类时,y=1为正样本,y=0为负样本,y'>>0.5为预测到的正样本,y'<<0.5为预测到的负样本,y'在0.5附近的为难分样本,可以把正样本想象成要检测的目标,负样本想象成背景,难分样本就是那种人类知道正负但模型一看就晕了傻傻蒙圈分不清的样本。
Focal loss 核心参数有两个,一个是γ,一个是α。简化focal loss:
1、对于难度权重调节部分(不考虑α): γ利用幂函数的快速放缩特性动态调节简单样本权重降低的速率。对于正样本y=1,根据上面公式,当样本为易分样本,pt很大,随着pt→1,(1-pt)接近于0,加之γ次方后易分样本的权重被显著降低降低,loss受到明显缩小,相反,如果在另一个极端,pt值很小,(1-pt)接近于 1,loss几乎不受影响,模型犯这么严重错误的情况很少,可以忽略不计,否则肯定模型或训练数据出错了,而难分样本的pt在0.5附近,显然其loss被缩小的程度远远小于易分样本;同理,对于负样本y=0,根据FL loss另一半公式,当样本为易分样本,pt很小,加之γ次方后loss明显缩小,相反,如果在另一个极端,pt值接近于1,loss几乎不受影响,模型犯这么严重错误的情况也很少,可以忽略不计,而难分样本的pt在0.5附近,显然其loss被缩小的程度远远小于易分样本。整体效果来看,通过显著减小了易分类样本的loss使模型在训练时更专注于难分类的样本,从而提升难分样本的检出率。值得说明的是,γ越大幂函数放缩程度越大,对困难样本的重视程度越大,即越专注于比较困难的样本,但γ太大也会把易分样本的检测能力给放缩消失一部分,因此放缩要适度,且不要贪杯喔~,作者建议在 (0.5, 10.0) 范围尝试。
公式中采用(1-pt)γ:共同的作用平衡难易样本的重要性,实验发现γ=2效果最优。
2、对于类别权重调节部分(不考虑(1-pt)γ):α用于平衡不同类别样本的重要性,实际训练中正样本很少,所以作者希望通过提高α来提高正样本的重要性,然而调节α的实验效果提升并不明显。
因此,FL损失函数训练过程关注对象的排序为正难>负难>正易>负易:
3、当γ=0,α=0.5时,FL退化为CE损失,即FL=0.5*CE。
Focal Loss最初是解决onestage目标检测问题而提出的,后来大量应用在anchorfree目标检测算法中,实际在图片分类中也是可以使用的。
二分类的情况参考下面代码:
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.logits = logits
self.reduce = reduce
def forward(self, inputs, targets):
if self.logits:
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
else:
BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False)
pt = torch.exp(-BCE_loss)
F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
if self.reduce:
return torch.mean(F_loss)
else:
return F_loss
上面例子中的类别只有正负样本两类,如果应用到多分类情况参考下面代码,将sigmoid换成softmax预测:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class FocalLoss(nn.Module):
r"""
This criterion is a implemenation of Focal Loss, which is proposed in
Focal Loss for Dense Object Detection.
Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
The losses are averaged across observations for each minibatch.
Args:
alpha(1D Tensor, Variable) : the scalar factor for this criterion
gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
putting more focus on hard, misclassified examples
size_average(bool): By default, the losses are averaged over observations for each minibatch.
However, if the field size_average is set to False, the losses are
instead summed for each minibatch.
"""
def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
super(FocalLoss, self).__init__()
if alpha is None:
self.alpha = Variable(torch.ones(class_num, 1))
else:
if isinstance(alpha, Variable):
self.alpha = alpha
else:
self.alpha = Variable(alpha)
self.gamma = gamma
self.class_num = class_num
self.size_average = size_average
def forward(self, inputs, targets):
N = inputs.size(0)
C = inputs.size(1)
P = F.softmax(inputs)
class_mask = inputs.data.new(N, C).fill_(0)
class_mask = Variable(class_mask)
ids = targets.view(-1, 1)
class_mask.scatter_(1, ids.data, 1.)
#print(class_mask)
if inputs.is_cuda and not self.alpha.is_cuda:
self.alpha = self.alpha.cuda()
alpha = self.alpha[ids.data.view(-1)]
probs = (P*class_mask).sum(1).view(-1,1)
log_p = probs.log()
#print('probs size= {}'.format(probs.size()))
#print(probs)
batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p
#print('-----bacth_loss------')
#print(batch_loss)
if self.size_average:
loss = batch_loss.mean()
else:
loss = batch_loss.sum()
return loss
参考文章:[机器学习] XGBoost 自定义损失函数-FocalLoss_VinkinTsang的博客-CSDN博客_xgboost自定义损失函数
https://www.jianshu.com/p/30043bcc90b6