之前在做一个分类的任务,受师兄指点,知道了这样一个损失,但是笔者的实验结果并不是很好,反而和过采样的结果相比略有下降,但是权当学习了
focal loss是何凯明大神在论文:Focal Loss for Dense Object Detection 中提出来用于目标检测任务的,但是本质也还是分类任务,所以就尝试把focal loss用在多分类任务上。
focal loss的形式如下:
可以看到这里有两个超参,是用来平衡样本数量的,相当于惩罚项,用来控制难分样本的挖掘,=1的时候就是我们平常使用的交叉熵损失,但是当取值增大时,易分样本的损失就会变小,难分样本的损失则相对来说较大。在论文中=0.25,=2,这是作者认为的最佳参数,如果需要调整,当增大时,的值需要略微减小(的取值范围是(0,1])(In general should be decreased slightly as is increased)。这里的就是predictions。
然后贴出笔者实现的tensorflow版本多分类focal loss:
focal_loss = tf.reduce_mean(- tf.reduce_sum(alpha * label * tf.pow(1-pred, gamma) * tf.log(pred), reduction_indices=[1]))
对于多分类,alpha的长度应该和类别数一致,直觉下应该是样本越多,alpha越小(多出好多超参...)。
如果有哪位大神通过focal loss得到较好的分类效果希望可以交流一下。