Focal Loss: Paper Notes and Formula Derivation - AIUAI
Paper: Focal Loss for Dense Object Detection - ICCV 2017
Authors: Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár
Team: FAIR
The most accurate object detectors are typically two-stage, R-CNN-style methods, in which a classifier is applied to a sparse set of candidate object locations. One-stage detectors, by contrast, sample possible object locations regularly and densely; they are faster and simpler, but their accuracy has so far trailed that of two-stage detectors. (This is the problem the paper addresses.)
The paper finds that the extreme foreground-background class imbalance encountered while training dense detectors is the central cause.
To address it, the paper proposes Focal Loss, which modifies the standard cross-entropy loss so that it down-weights the loss assigned to well-classified examples, thereby countering the class imbalance.
Focal Loss focuses training on a sparse set of hard examples and prevents the vast number of easy negatives from overwhelming the detector during training.
Focal Loss is a dynamically scaled cross-entropy loss in which the scaling factor decays to 0 as confidence in the correct class increases (see Figure 1).
The scaling factor automatically down-weights the contribution of easy examples during training, so the model rapidly focuses on hard examples.
Detection performance of RetinaNet, the object detector built on Focal Loss, is also reported in the paper.
Focal Loss is designed to address the extreme foreground-background class imbalance that arises when training one-stage detectors (e.g., foreground : background = 1 : 1000).
Focal Loss is introduced starting from the cross entropy (CE) loss for binary classification:
$$CE(p, y) = \begin{cases} -\log(p) & \text{if } y = 1 \\ -\log(1-p) & \text{otherwise} \end{cases}$$
where $y \in \lbrace +1, -1 \rbrace$ is the ground-truth class and $p \in [0, 1]$ is the model's estimated probability for the class $y = 1$.
For notational convenience, define $p_t$:
$$p_t = \begin{cases} p & \text{if } y = 1 \\ 1-p & \text{otherwise} \end{cases}$$
so that $CE(p, y) = CE(p_t) = -\log(p_t)$.
The CE loss is the upper blue curve in Figure 1. One notable property is that even easily classified examples ($p_t \gg 0.5$) incur a loss with non-trivial magnitude; summed over a large number of easy examples, these small loss values can overwhelm the rare class.
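To make the notation concrete, here is a minimal Python sketch of $CE(p, y)$ in the $p_t$ form (purely illustrative, not from the paper):

import math

def binary_ce(p, y):
    # CE(p, y) = -log(p_t), with p_t = p if y == 1 else 1 - p
    p_t = p if y == 1 else 1.0 - p
    return -math.log(p_t)

print(binary_ce(0.9, 1))   # easy positive: small but non-zero loss (~0.105)
print(binary_ce(0.6, -1))  # hard negative: much larger loss (~0.916)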
A common way to address class imbalance is to introduce a weighting factor $\alpha \in [0, 1]$ for class +1 and $1 - \alpha$ for class -1.
For notational convenience, define $\alpha_t$:
$$\alpha_t = \begin{cases} \alpha & \text{if } y = 1 \\ 1-\alpha & \text{otherwise} \end{cases}$$
The $\alpha$-balanced CE loss is then:
$$CE(p_t) = -\alpha_t \log(p_t)$$
Although $\alpha$ balances the importance of positive/negative examples, it does not differentiate easy from hard examples.
Focal Loss therefore reshapes the loss to down-weight easy examples and focus training on hard negatives.
A modulating factor $(1 - p_t)^{\gamma}$ is added to the CE loss, where $\gamma \ge 0$ is a tunable focusing parameter.
Focal Loss is defined as:
$$FL(p_t) = -(1 - p_t)^{\gamma} \log(p_t)$$
Figure 1 visualizes the loss for several values of $\gamma \in [0, 5]$.
Focal Loss has two properties:
1. When an example is misclassified and $p_t$ is small, the modulating factor is near 1 and the loss is almost unaffected; as $p_t \to 1$, the factor goes to 0 and the loss of well-classified examples is down-weighted.
2. The focusing parameter $\gamma$ smoothly adjusts the rate at which easy examples are down-weighted; $\gamma = 0$ recovers CE, and the effect of the modulating factor grows as $\gamma$ increases.
Intuitively, the modulating factor reduces the loss contribution of easy examples and extends the range of $p_t$ over which an example receives a low loss.
For example, with $\gamma = 2$, an example classified with $p_t = 0.9$ has a loss 100× lower than with CE, and an example with $p_t \approx 0.968$ has a loss roughly 1000× lower. This in turn increases the relative importance of correcting misclassified examples: with $\gamma = 2$ and $p_t \le 0.5$, the loss is scaled down by at most 4×.
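These ratios are easy to verify numerically; a small sketch (not from the paper):

import math

def ce(p_t):
    return -math.log(p_t)

def fl(p_t, gamma=2.0):
    return (1.0 - p_t) ** gamma * ce(p_t)

for p_t in (0.5, 0.9, 0.968):
    print(p_t, ce(p_t) / fl(p_t))  # ratios ~4x, ~100x, ~1000x for gamma = 2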
In practice, the paper uses an $\alpha$-balanced variant of Focal Loss:
$$FL(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$
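Putting the pieces together, a minimal per-example sketch of this $\alpha$-balanced form might look as follows (illustrative only; it assumes $y \in \lbrace +1, -1 \rbrace$ and a scalar probability $p$, and the function name is ours):

import math

def focal_loss(p, y, alpha=0.25, gamma=2.0):
    # alpha-balanced Focal Loss for a single binary example
    p_t = p if y == 1 else 1.0 - p
    alpha_t = alpha if y == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

print(focal_loss(0.9, 1))   # easy positive: tiny loss
print(focal_loss(0.1, 1))   # hard positive: loss stays close to alpha * CE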
The exact form of Focal Loss is not essential; here is an alternative instantiation.
Let $p = \sigma(x) = \frac{1}{1 + e^{-x}}$,
and define $p_t$ (analogous to the earlier definition of $p_t$):
$$p_t = \begin{cases} p & \text{if } y = 1 \\ 1-p & \text{otherwise} \end{cases}$$
Define $x_t = yx$, where $y \in \lbrace +1, -1 \rbrace$ is the ground-truth class.
Then $p_t = \sigma(x_t) = \frac{1}{1 + e^{-yx}}$.
When $x_t > 0$ the example is correctly classified and $p_t > 0.5$.
We have:
$$\frac{d p_t}{d x} = \frac{y\, e^{-yx}}{(1 + e^{-yx})^2} = y\, p_t (1 - p_t) = -y\, p_t (p_t - 1)$$
For the cross-entropy loss $CE(p_t) = -\log(p_t)$, using $\frac{d \ln x}{d x} = \frac{1}{x}$:
$$\frac{d\, CE(p_t)}{d x} = \frac{d\, CE(p_t)}{d p_t} \cdot \frac{d p_t}{d x} = \left(-\frac{1}{p_t}\right) \cdot \left(-y\, p_t (p_t - 1)\right) = y (p_t - 1)$$
For the Focal Loss $FL(p_t) = -(1 - p_t)^{\gamma} \log(p_t)$, with $\gamma$ a constant:
$$\frac{d\, FL(p_t)}{d x} = \frac{d (1-p_t)^{\gamma}}{d x} \cdot (-\log(p_t)) + (1-p_t)^{\gamma} \cdot \frac{d\, CE(p_t)}{d x}$$
$$\frac{d\, FL(p_t)}{d x} = \left(\gamma (1-p_t)^{\gamma-1} \cdot \frac{d (1-p_t)}{d p_t}\right) \cdot \frac{d p_t}{d x} \cdot (-\log(p_t)) + (1-p_t)^{\gamma} \cdot y (p_t - 1)$$
$$\frac{d\, FL(p_t)}{d x} = \left(\gamma (1-p_t)^{\gamma-1} \cdot (-1)\right) \cdot \left(-y\, p_t (p_t - 1)\right) \cdot (-\log(p_t)) + y (1-p_t)^{\gamma} (p_t - 1)$$
$$\frac{d\, FL(p_t)}{d x} = \gamma (1-p_t)^{\gamma}\, y\, p_t \log(p_t) + y (1-p_t)^{\gamma} (p_t - 1)$$
$$\frac{d\, FL(p_t)}{d x} = y (1-p_t)^{\gamma} \left(\gamma\, p_t \log(p_t) + p_t - 1\right)$$
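As a sanity check on this closed form, the following sketch compares it against a centered finite-difference estimate (assuming the sigmoid parameterization $p_t = \sigma(yx)$ above):

import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def fl(x, y, gamma=2.0):
    p_t = sigmoid(y * x)
    return -(1.0 - p_t) ** gamma * math.log(p_t)

def dfl_dx(x, y, gamma=2.0):
    # closed form: y * (1 - p_t)^gamma * (gamma * p_t * log(p_t) + p_t - 1)
    p_t = sigmoid(y * x)
    return y * (1.0 - p_t) ** gamma * (gamma * p_t * math.log(p_t) + p_t - 1.0)

x, y, eps = 0.7, -1, 1e-6
numeric = (fl(x + eps, y) - fl(x - eps, y)) / (2 * eps)
print(numeric, dfl_dx(x, y))  # the two values agree to ~6 decimal places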
Further, let $p_t^* = \sigma(\gamma x_t + \beta)$ and define $FL^*(p_t^*) = -\log(p_t^*) / \gamma$, where $\gamma$ and $\beta$ are constants. Then:
$$\frac{d\, FL^*(p_t^*)}{d x} = -\frac{1}{p_t^*} \cdot \frac{1}{\gamma} \cdot \frac{d p_t^*}{d (\gamma x_t + \beta)} \cdot \frac{d (\gamma x_t + \beta)}{d x}$$
$$\frac{d\, FL^*(p_t^*)}{d x} = -\frac{1}{p_t^*} \cdot \frac{1}{\gamma} \cdot \left(-y\, p_t^* (p_t^* - 1)\, \gamma\right) = y (p_t^* - 1)$$
$FL^*$ therefore has two parameters, $\gamma$ and $\beta$, which control the steepness and shift of the loss curve (see Figure 5 of the paper).
Derivative of $CE$ with respect to $x$:
$$\frac{d\, CE}{d x} = y (p_t - 1)$$
Derivative of $FL$ with respect to $x$:
$$\frac{d\, FL}{d x} = y (1-p_t)^{\gamma} (\gamma\, p_t \log(p_t) + p_t - 1)$$
Derivative of $FL^*$ with respect to $x$:
$$\frac{d\, FL^*}{d x} = y (p_t^* - 1)$$
As shown in Figure 6 of the paper, for high-confidence predictions the derivatives of all three losses tend to -1 or 0. However, unlike $CE$, for effective settings of $FL$ and $FL^*$ the derivative is already small as soon as $x_t > 0$.
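The sketch below illustrates this by evaluating the three derivative formulas at a few correctly classified points; the values of $\gamma$ and $\beta$ here are only illustrative, not the paper's tuned settings:

import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dce_dx(x, y):
    return y * (sigmoid(y * x) - 1.0)

def dfl_dx(x, y, gamma=2.0):
    p_t = sigmoid(y * x)
    return y * (1.0 - p_t) ** gamma * (gamma * p_t * math.log(p_t) + p_t - 1.0)

def dfl_star_dx(x, y, gamma=2.0, beta=1.0):  # beta value chosen only for illustration
    return y * (sigmoid(gamma * y * x + beta) - 1.0)

for x_t in (0.5, 1.0, 2.0):  # y = +1, so x_t = x > 0: already correctly classified
    print(x_t, dce_dx(x_t, 1), dfl_dx(x_t, 1), dfl_star_dx(x_t, 1))
# |dFL/dx| and |dFL*/dx| shrink much faster than |dCE/dx| as x_t grows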
The Focal Loss function:
$$FL(p_t) = -\alpha (1 - p_t)^{\gamma} \log(p_t)$$
where:
$$p_t = \begin{cases} p & \text{if } y = 1 \\ 1-p & \text{otherwise} \end{cases}$$
The softmax function:
$$p_i = \frac{e^{x_i}}{\sum_{k=1}^{K} e^{x_k}}$$
where $K$ is the number of classes, $x$ is the output vector of the network's fully-connected layer, and $x_i$ is its $i$-th element.
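A small, numerically stable softmax sketch in Python (subtracting the max is only for numerical stability and does not change the result):

import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

p = softmax(np.array([2.0, 1.0, 0.1]))
print(p, p.sum())  # class probabilities that sum to 1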
Let $t$ denote the index of the ground-truth class, so that $FL$ depends on $x$ only through $p_t$. Then the derivative of $FL$ with respect to $x_i$ is:
$$\frac{d\, FL}{d x_i} = \frac{d\, FL}{d p_t} \cdot \frac{d p_t}{d x_i}$$
First,
$$\frac{d\, FL}{d p_t} = -\alpha \left(\frac{d (1-p_t)^{\gamma}}{d p_t} \cdot \log(p_t) + (1-p_t)^{\gamma} \cdot \frac{d \log(p_t)}{d p_t}\right)$$
$$\frac{d\, FL}{d p_t} = -\alpha \left(-\gamma (1-p_t)^{\gamma - 1} \log(p_t) + (1-p_t)^{\gamma} \cdot \frac{1}{p_t}\right)$$
The derivative of the softmax output $p_j$ with respect to the input $x_i$ is:
$$\frac{d p_j}{d x_i} = \frac{d}{d x_i} \frac{e^{x_j}}{\sum_{k=1}^{K} e^{x_k}} = \frac{\frac{d (e^{x_j})}{d x_i} \cdot \sum_{k=1}^{K} e^{x_k} - e^{x_j} \cdot \frac{d \left(\sum_{k=1}^{K} e^{x_k}\right)}{d x_i}}{\left(\sum_{k=1}^{K} e^{x_k}\right)^2}$$
When $i = j$:
$$\frac{d p_j}{d x_i} = \frac{e^{x_i} \cdot \sum_{k=1}^{K} e^{x_k} - e^{x_i} \cdot e^{x_i}}{\left(\sum_{k=1}^{K} e^{x_k}\right)^2} = \frac{e^{x_i}}{\sum_{k=1}^{K} e^{x_k}} - \frac{e^{x_i}}{\sum_{k=1}^{K} e^{x_k}} \cdot \frac{e^{x_i}}{\sum_{k=1}^{K} e^{x_k}} = p_i (1 - p_i)$$
When $i \neq j$:
$$\frac{d p_j}{d x_i} = \frac{0 - e^{x_j} \cdot e^{x_i}}{\left(\sum_{k=1}^{K} e^{x_k}\right)^2} = -\frac{e^{x_i}}{\sum_{k=1}^{K} e^{x_k}} \cdot \frac{e^{x_j}}{\sum_{k=1}^{K} e^{x_k}} = -p_i\, p_j$$
The softmax derivative is therefore:
$$\frac{d p_j}{d x_i} = \begin{cases} p_i (1 - p_i) & \text{if } i = j \\ -p_i\, p_j & \text{if } i \neq j \end{cases}$$
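This Jacobian is easy to verify numerically; a small finite-difference sketch (not from the original post):

import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

x = np.array([1.0, 2.0, 0.5])
p = softmax(x)
jac = np.diag(p) - np.outer(p, p)  # element [i, j] = p_i if i == j else 0, minus p_i * p_j
eps, i = 1e-6, 0
x_plus, x_minus = x.copy(), x.copy()
x_plus[i] += eps
x_minus[i] -= eps
print((softmax(x_plus) - softmax(x_minus)) / (2 * eps))  # numerical dp_j/dx_i for all j
print(jac[i])                                            # analytic row i: matches the line above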
Hence, taking $j = t$ (the ground-truth class) and combining the two factors:
$$\frac{d\, FL}{d x_i} = \begin{cases} \alpha \left(-\gamma (1-p_t)^{\gamma - 1}\, p_t \log(p_t) + (1-p_t)^{\gamma}\right) (p_t - 1) & \text{if } i = t \\ \alpha \left(-\gamma (1-p_t)^{\gamma - 1}\, p_t \log(p_t) + (1-p_t)^{\gamma}\right) p_i & \text{if } i \neq t \end{cases}$$
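The closed-form gradient above can again be checked against finite differences; a sketch under the same softmax assumptions (index $t$ is the ground-truth class):

import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def fl(x, t, alpha=0.25, gamma=2.0):
    p_t = softmax(x)[t]
    return -alpha * (1.0 - p_t) ** gamma * np.log(p_t)

def dfl_dx(x, t, alpha=0.25, gamma=2.0):
    p = softmax(x)
    p_t = p[t]
    common = alpha * (-gamma * (1 - p_t) ** (gamma - 1) * p_t * np.log(p_t) + (1 - p_t) ** gamma)
    grad = common * p             # i != t branch: common * p_i
    grad[t] = common * (p_t - 1)  # i == t branch
    return grad

x, t, eps = np.array([1.5, 0.3, -0.7]), 0, 1e-6
numeric = np.array([(fl(x + eps * np.eye(3)[i], t) - fl(x - eps * np.eye(3)[i], t)) / (2 * eps)
                    for i in range(3)])
print(numeric)
print(dfl_dx(x, t))  # matches the finite-difference gradient to ~6 decimals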
FocalLoss-PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.size_average = size_average
        if isinstance(alpha, (float, int)):
            if alpha > 1:
                raise ValueError('Not supported value, alpha should be no greater than 1.0')
            # binary case: per the alpha_t definition above, class 1 gets alpha, class 0 gets 1 - alpha
            self.alpha = torch.Tensor([1.0 - alpha, alpha])
        elif isinstance(alpha, list):
            # multi-class case: one weight per class
            self.alpha = torch.Tensor(alpha)
        self.alpha /= torch.sum(self.alpha)

    def forward(self, input, target):
        # input:  class probabilities, [N, C] or [N, C, H, W] (or [N, C, D, H, W])
        # target: class indices,       [N]    or [N, H, W]    (or [N, D, H, W])
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # [N,C,H,W] -> [N,C,H*W]
            input = input.transpose(1, 2).contiguous()            # [N,C,H*W] -> [N,H*W,C]
            input = input.view(-1, input.size(2))                 # [N,H*W,C] -> [N*H*W,C]
        target = target.view(-1, 1)                               # -> [N*H*W, 1]
        if self.alpha.device != input.device:
            self.alpha = self.alpha.to(input.device)
        logpt = torch.log(input + 1e-10)       # small epsilon to avoid log(0)
        logpt = logpt.gather(1, target)        # log-probability of the target class
        logpt = logpt.view(-1)
        pt = torch.exp(logpt)                  # probability of the target class
        alpha = self.alpha.gather(0, target.view(-1))
        loss = -1 * alpha * torch.pow(1.0 - pt, self.gamma) * logpt
        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss
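A quick usage sketch for the module above (shapes and values are hypothetical; the module expects class probabilities, so a softmax is applied to the logits first):

logits = torch.randn(8, 5, requires_grad=True)   # 8 samples, 5 classes
probs = F.softmax(logits, dim=1)
target = torch.randint(0, 5, (8,))

criterion = FocalLoss(alpha=[1.0] * 5, gamma=2)  # uniform per-class alpha for 5 classes
loss = criterion(probs, target)
loss.backward()
print(loss.item(), logits.grad.shape)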
keras-focal-loss
Binary Focal Loss and Categorical/Multiclass Focal Loss, implemented with Keras and the TensorFlow backend.
Two main parameters are involved: alpha and gamma.
Usage:
model.compile(optimizer='adam', loss=categorical_focal_loss(gamma=2.0, alpha=0.25), metrics=['accuracy'])
Implementation:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Oct 19 08:20:58 2018
@OS: Ubuntu 18.04
@IDE: Spyder3
@author: Aldi Faizal Dimara (Steam ID: phenomos)
"""

import keras.backend as K
import tensorflow as tf

def categorical_focal_loss(gamma=2.0, alpha=0.25):
    """
    Implementation of Focal Loss from the paper for multiclass classification.
    Formula:
        loss = -alpha * ((1 - p)^gamma) * log(p)
    Parameters:
        alpha -- the same as the weighting factor in balanced cross entropy
        gamma -- focusing parameter for the modulating factor (1 - p)
    Default values:
        gamma -- 2.0 as mentioned in the paper
        alpha -- 0.25 as mentioned in the paper
    """
    def focal_loss(y_true, y_pred):
        # Define epsilon so that backpropagation does not produce NaN
        # for the 0-divisor case
        epsilon = K.epsilon()
        # Clip the prediction value
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
        # Calculate cross entropy
        cross_entropy = -y_true * K.log(y_pred)
        # Calculate the weight, which combines the modulating factor and the weighting factor
        weight = alpha * y_true * K.pow((1 - y_pred), gamma)
        # Calculate focal loss
        loss = weight * cross_entropy
        # Sum the losses over the classes
        loss = K.sum(loss, axis=1)
        return loss
    return focal_loss

def binary_focal_loss(gamma=2.0, alpha=0.25):
    """
    Implementation of Focal Loss from the paper for binary classification.
    Formula:
        loss = -alpha_t * ((1 - p_t)^gamma) * log(p_t)
        p_t = y_pred,      if y_true = 1
        p_t = 1 - y_pred,  otherwise
        alpha_t = alpha,      if y_true = 1
        alpha_t = 1 - alpha,  otherwise
        cross_entropy = -log(p_t)
    Parameters:
        alpha -- the same as the weighting factor in balanced cross entropy
        gamma -- focusing parameter for the modulating factor (1 - p)
    Default values:
        gamma -- 2.0 as mentioned in the paper
        alpha -- 0.25 as mentioned in the paper
    """
    def focal_loss(y_true, y_pred):
        # Define epsilon so that backpropagation does not produce NaN
        # for the 0-divisor case
        epsilon = K.epsilon()
        # Clip the prediction value
        y_pred = K.clip(y_pred, epsilon, 1.0 - epsilon)
        # Calculate p_t
        p_t = tf.where(K.equal(y_true, 1), y_pred, 1 - y_pred)
        # Calculate alpha_t
        alpha_factor = K.ones_like(y_true) * alpha
        alpha_t = tf.where(K.equal(y_true, 1), alpha_factor, 1 - alpha_factor)
        # Calculate cross entropy
        cross_entropy = -K.log(p_t)
        # Calculate the weight, which combines the modulating factor and the weighting factor
        weight = alpha_t * K.pow((1 - p_t), gamma)
        # Calculate focal loss
        loss = weight * cross_entropy
        # Sum the losses over the classes
        loss = K.sum(loss, axis=1)
        return loss
    return focal_loss
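A hypothetical sanity check outside of model.compile, evaluating the categorical focal loss on a small batch of one-hot labels and predicted probabilities:

import numpy as np

y_true = K.constant(np.array([[1.0, 0.0, 0.0],
                              [0.0, 1.0, 0.0]]))
y_pred = K.constant(np.array([[0.9, 0.05, 0.05],
                              [0.2, 0.7, 0.1]]))

loss_fn = categorical_focal_loss(gamma=2.0, alpha=0.25)
print(K.eval(loss_fn(y_true, y_pred)))  # one focal-loss value per sample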