多分类focal loss及其tensorflow实现

之前在做一个分类的任务,受师兄指点,知道了这样一个损失,但是笔者的实验结果并不是很好,反而和过采样的结果相比略有下降,但是权当学习了

focal loss是何凯明大神在论文:Focal Loss for Dense Object Detection 中提出来用于目标检测任务的,但是本质也还是分类任务,所以就尝试把focal loss用在多分类任务上。

focal loss的形式如下:

FL(p{_{t}}) = -\alpha _{t}(1-p_{t})^{\gamma }log({p_{t}})

可以看到这里有两个超参,\alpha _{t}是用来平衡样本数量的,{\gamma }相当于惩罚项,用来控制难分样本的挖掘,{\gamma }=1的时候就是我们平常使用的交叉熵损失,但是当{\gamma }取值增大时,易分样本的损失就会变小,难分样本的损失则相对来说较大。在论文中\alpha _{t}=0.25,{\gamma }=2,这是作者认为的最佳参数,如果需要调整,当{\gamma }增大时,\alpha _{t}的值需要略微减小(\alpha _{t}的取值范围是(0,1])(In general \alpha _{t} should be decreased slightly as {\gamma } is increased)。这里的p_{t}就是predictions。

然后贴出笔者实现的tensorflow版本多分类focal loss:

focal_loss = tf.reduce_mean(- tf.reduce_sum(alpha * label * tf.pow(1-pred, gamma) * tf.log(pred), reduction_indices=[1]))

对于多分类,alpha的长度应该和类别数一致,直觉下应该是样本越多,alpha越小(多出好多超参...)。

如果有哪位大神通过focal loss得到较好的分类效果希望可以交流一下。

你可能感兴趣的:(python,focal,loss,tensorflow)