Ref:
- https://openaccess.thecvf.com/content_ICCV_2017/papers/Lin_Focal_Loss_for_ICCV_2017_paper.pdf
- https://zhuanlan.zhihu.com/p/80594704
- https://arxiv.org/pdf/1811.05181.pdf
背景
工作中处理二分类问题,数据大多是长尾分布,即正样本远小于负样本。一般来说,通过调整阈值(置信度),就可以满足上线需求。但总是有一些正样本,得分较低,希望找到一些办法,提高这些得分很低的正例分数,且负样本得分不被拉高太多。
模型通过梯度更新进行训练,实际应用中,大部分的样本是容易区分的,而这些样本贡献了主要的loss,模型偏向于这些样本,在部分难区分的样本上效果不好。
所以,为提高模型效果,要解决两个问题:
- 如何处理样本不均衡问题?
- 如何有效处理{正难,负难}的样本?
Focal Loss
主要应用在目标检测,实际应用范围很广。
分类问题中,常见的loss是cross-entropy:
为了解决正负样本不均衡,乘以权重:
一般根据各类别数据占比,对进行取值,即当class_1占比为30%时,。
我们希望模型能更关注容易错分的数据,反向思考,就是让模型别那么关注容易分类的样本。因此,Focal Loss的思路就是,把高置信度的样本损失降低。
多分类样本:
不同取值情况如下图:
模型是如何通过控制损失的衰减的呢?
当样本被误分类时,p很小,很大,loss不怎么受影响。当样本被正确分类,p很大,变小,loss衰减。
比如:当为1,p为0.98时,,这个容易分类的样本,损失和cross-entropy相比,衰减了100倍。
代码
# 二分类
class BCEFocalLoss(torch.nn.Module):
"""
https://github.com/louis-she/focal-loss.pytorch/blob/master/focal_loss.py
二分类的Focalloss alpha 固定
"""
def __init__(self, gamma=2, alpha=0.25, reduction='sum'):
super().__init__()
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
def forward(self, preds, targets):
"preds:[B,C],targets:[B]"
pt = torch.sigmoid(preds)
pt = pt.clamp(min=0.0001,max = 1.0) # 概率过低,logpt后,loss返回nan
# 我在gpu上使用时,不加.to(targets.device),报错
targets = torch.zeros(targets.size(0),2).to(targets.device).scatter_(1,targets.view(-1,1),1)
loss = - self.alpha * (1 - pt) ** self.gamma * targets * torch.log(pt) - \
(1 - self.alpha) * pt ** self.gamma * (1 - targets) * torch.log(1 - pt)
if self.reduction == 'elementwise_mean':
loss = torch.mean(loss)
elif self.reduction == 'sum':
loss = torch.sum(loss)
return loss
# 多分类
class FocalLoss(nn.Module):
"""
Ref: https://github.com/yatengLG/Focal-Loss-Pytorch/blob/master/Focal_Loss.py
FL(pt) = -alpha_t(1-pt)^gamma log(pt)
alpha: 类别权重,常数时,类别权重为:[alpha,1-alpha,1-alpha,...];列表时,表示对应类别权重
gamma: 难易分类的样本权重,使得模型更关注难分类的样本
优点:帮助区分难分类的不均衡样本数据
"""
def __init__(self, num_classes, alpha=0.25,gamma=2,reduce=True):
super(FocalLoss,self).__init__()
self.num_classes = num_classes
self.gamma = gamma
self.reduce = reduce
if alpha is None:
self.alpha = torch.ones(self.num_classes,1)
else:
self.alpha = torch.zeros(num_classes)
self.alpha[0] = alpha
self.alpha[1:] += (1-alpha)
def forward(self,preds,targets):
"preds:[B,C],targets:[B]"
preds = preds.view(-1,preds.size(-1)) #[B,C]
self.alpha = self.alpha.to(preds.device)
logpt = F.log_softmax(preds,dim=1)
pt = F.softmax(preds).clamp(min=0.0001,max=1.0)
logpt = logpt.gather(1,targets.view(-1,1)) # 对应类别值
pt = pt.gather(1,targets.view(-1,1))
self.alpha = self.alpha.gather(0,targets.view(-1))
loss = -(1-pt) **self.gamma *logpt
loss = self.alpha*loss.t()
if self.reduce:
return loss.mean()
else:
return loss.sum()
GHM - gradient harmonizing mechanism
Focal Loss对容易分类的样本进行了损失衰减,让模型更关注难分样本,并通过和进行调参。
GHM提到:
- 有一部分难分样本就是离群点,不应该给他太多关注;
- 梯度密度可以直接统计得到,不需要调参。
GHM认为,类别不均衡可总结为难易分类样本的不均衡,而这种难分样本的不均衡又可视为梯度密度分布的不均衡。假设一个正样本被正确分类,它就是正易样本,损失不大,模型不能从中获益。而一个错误分类的样本,更能促进模型迭代。实际应用中,大量的样本都是属于容易分类的类型,这种样本一个起不了太大作用,但量级过大,在模型进行梯度更新时,起主要作用,使得模型朝这类数据更新。
- 图示左,样本梯度分布。
梯度模长(gradient norm)在很小和很大时,密度较大。前者,表示了大量容易分类的样本,所以梯度很低。而后者,文中认为是离群点,即便模型收敛,损失仍然很大。 - 图示中,经过修正后的梯度分布。
和CE,FL相比,GHM-C根据梯度密度,大量容易分类的样本和离群点的累计梯度被降级,达到样本均衡,使得模型更加有效稳定。 - 图示右,样本集梯度贡献。
经过GHM-C的梯度密度调整,各种难易分类的样本分布更加平滑。
简而言之:Focal Loss是从置信度p来调整loss,GHM通过一定范围置信度p的样本数来调整loss。
梯度模长
梯度模长:原文中用表示真实标签,这里统一符号,用y表示:
推理:
则:
梯度密度(Gradient Density)
梯度模长分布不均,引入梯度密度:
在N个样本中,梯度模长分布在范围的个数:
区间长度:
梯度密度协调参数:
上式分母,可视为对附近样本进行归一化。如果梯度分布均匀,则,如果密度过高,则意味着要降级处理。
GHM loss计算
代码
def _expand_binary_labels(labels,label_weights,label_channels):
bin_labels = labels.new_full((labels.size(0), label_channels),0)
inds = torch.nonzero(labels>=1).squeeze()
if inds.numel() >0:
bin_labels[inds,labels[inds]] = 1
bin_label_weights = label_weights.view(-1,1).expand(label_weights.size(0),label_channels)
return bin_labels, bin_label_weights
class GHMC(nn.Module):
"""GHM Classification Loss.
Ref:https://github.com/libuyu/mmdetection/blob/master/mmdet/models/losses/ghm_loss.py
Details of the theorem can be viewed in the paper
"Gradient Harmonized Single-stage Detector".
https://arxiv.org/abs/1811.05181
Args:
bins (int): Number of the unit regions for distribution calculation.
momentum (float): The parameter for moving average.
use_sigmoid (bool): Can only be true for BCE based loss now.
loss_weight (float): The weight of the total GHM-C loss.
"""
def __init__(self, bins=10, momentum=0, use_sigmoid=True, loss_weight=1.0,alpha=None):
super(GHMC, self).__init__()
self.bins = bins
self.momentum = momentum
edges = torch.arange(bins + 1).float() / bins
self.register_buffer('edges', edges)
self.edges[-1] += 1e-6
if momentum > 0:
acc_sum = torch.zeros(bins)
self.register_buffer('acc_sum', acc_sum)
self.use_sigmoid = use_sigmoid
if not self.use_sigmoid:
raise NotImplementedError
self.loss_weight = loss_weight
self.label_weight = alpha
def forward(self, pred, target, label_weight =None, *args, **kwargs):
"""Calculate the GHM-C loss.
Args:
pred (float tensor of size [batch_num, class_num]):
The direct prediction of classification fc layer.
target (float tensor of size [batch_num, class_num]):
Binary class target for each sample.
label_weight (float tensor of size [batch_num, class_num]):
the value is 1 if the sample is valid and 0 if ignored.
Returns:
The gradient harmonized loss.
"""
# the target should be binary class label
# if pred.dim() != target.dim():
# target, label_weight = _expand_binary_labels(
# target, label_weight, pred.size(-1))
# 我的pred输入为[B,C],target输入为[B]
target = torch.zeros(target.size(0),2).to(target.device).scatter_(1,target.view(-1,1),1)
# 暂时不清楚这个label_weight输入形式,默认都为1
if label_weight is None:
label_weight = torch.ones([pred.size(0),pred.size(-1)]).to(target.device)
target, label_weight = target.float(), label_weight.float()
edges = self.edges
mmt = self.momentum
weights = torch.zeros_like(pred)
# gradient length
# sigmoid梯度计算
g = torch.abs(pred.sigmoid().detach() - target)
# 有效的label的位置
valid = label_weight > 0
# 有效的label的数量
tot = max(valid.float().sum().item(), 1.0)
n = 0 # n valid bins
for i in range(self.bins):
# 将对应的梯度值划分到对应的bin中, 0-1
inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
# 该bin中存在多少个样本
num_in_bin = inds.sum().item()
if num_in_bin > 0:
if mmt > 0:
# moment计算num bin
self.acc_sum[i] = mmt * self.acc_sum[i] \
+ (1 - mmt) * num_in_bin
# 权重等于总数/num bin
weights[inds] = tot / self.acc_sum[i]
else:
weights[inds] = tot / num_in_bin
n += 1
if n > 0:
# scale系数
weights = weights / n
loss = F.binary_cross_entropy_with_logits(
pred, target, weights, reduction='sum') / tot
return loss * self.loss_weight