1.原理
Focal Loss最先应用在目标检测任务上,它的提出主要是为了解决正负样本数量不平衡(有些地方叫做类别不平衡)以及难易分样本不均衡的问题。
在讲解之前,小编先谈谈何为难分样本和易分样本。
难分样本:指的是实际为正样本,但模型预测为正类的置信度很低或者说实际为负样本,但模型 预测为负类的置信度很低。
易分样本:指的是实际为正样本,且模型预测为正类的置信度很高或者说实际为负样本,且模型 预测为负类的置信度很高。
在进行模型训练过程中,相信大家都会发现一个实际存在的现象:假设正负样本均衡(一般对模型来说易分样本的数量会远远多于难分样本的数量),那么(a)难分样本:错分的样本或者正确分类但置信度很低(0.5附近)的样本,其单个样本的损失很高,但是所有难分样本的损失之和相对于整体损失来说占比却很低;(b)易分样本:正确分类且置信度很高(接近1)的样本,其单个损失很低,但所有易分样本的损失之和相对于整体损失来说占比很高。也就是说易分样本的损失之和将主导整体损失,这样将导致模型训练过程中该背景下的损失函数反向传播时对参数的更新并不会改善模型的预测能力,模型对难分样本的预测能依旧很差。而易分样本本身就能很好的被模型识别,难分样本才是模型最该关注的点。
Focal Loss就很好地解决了这个问题,它的具体公式如下:
其中代表模型预测某类别的概率(即置信度);是用来平衡正负样本数量的,样本数量多的赋予更小的值,样本数量少的赋予更大的值;是用来调节难分易分样本不均衡问题的,一般取,对易分样本的损失进行一个幂函数的降低。乘上这项可以使模型更加关注于难分样本,比如:假设模型预测某类别的置信度,属于易分样本,当时,,相当于将易分样本损失缩小了1000倍,而另一类预测的置信度,属于难分样本,,相当于只将难分样本的损失缩小了0.64倍,难分样本的损失之和在整体损失的占比中明显提升,说明模型更关注于难分样本,有利于模型预测能力的提升。
2.代码实现
其实Focal Loss可以看作是由交叉熵损失函数改进来的,是样本类别真实的概率(即标签),是已知的,一般训练是需要对标签进行one-hot编码,,最终交叉熵损失函数变为,Pytorch中的F.Cross_entropy()函数对应的数学公式就是这个。
这里简单描述一下F.Cross_entropy(x,y)函数的计算过程,代码实现时会用到这个函数:
①第一个参数x输入到模型中进行前向传播后进行softmax操作,使模型输出结果在0-1之间,结果记为x_soft;
②对x_soft做对数运算并取相反数,结果记为x_soft_log;
③对y进行one-hot编码,根据上面的交叉熵公式与x_soft_log进行点乘,只有与元素1对应的位置点乘后有非零值,因为乘的是1,所以点乘结果还是x_soft_log,其它位置是与0点乘,点乘结果都为0。
Focal Loss的代码实现如下:
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2, size_average=True, ignore_index=255):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.ignore_index = ignore_index
self.size_average = size_average
def forward(self, inputs, targets):
# F.cross_entropy(x,y)工作过程就是(Log_Softmax+NllLoss):①对x做softmax,使其满足归一化要求,结果记为x_soft;②对x_soft做对数运算
# 并取相反数,记为x_soft_log;③对y进行one-hot编码,编码后与x_soft_log进行点乘,只有元素为1的位置有值而且乘的是1,
# 所以点乘后结果还是x_soft_log
# 总之,F.cross_entropy(x,y)对应的数学公式就是CE(pt)=-1*log(pt)
ce_loss = F.cross_entropy(inputs, targets, reduction='none', ignore_index=self.ignore_index)
pt = torch.exp(-ce_loss) # pt是预测该类别的概率,要明白F.cross_entropy工作过程就能够理解
focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
if self.size_average:
return focal_loss.mean()
else:
return focal_loss.sum()