import torch
import torch.nn as nn
import numpy as np
class FocalLoss(nn.Module):
    """Focal loss that wraps an existing element-wise BCE-with-logits criterion.

    Example::

        criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)

    Each element of the wrapped loss is re-weighted as

        loss * alpha_t * (1 - p_t) ** gamma

    with ``p = sigmoid(pred)``, ``p_t = true * p + (1 - true) * (1 - p)`` and
    ``alpha_t = true * alpha + (1 - true) * (1 - alpha)``. The reduction the
    wrapped criterion was configured with ('mean', 'sum' or 'none') is applied
    to the re-weighted loss.

    Args:
        loss_fcn: Criterion that returns per-element losses once its
            ``reduction`` is set to ``'none'``; must be
            ``nn.BCEWithLogitsLoss()`` (``forward`` applies a sigmoid to
            ``pred`` itself, so the criterion must take logits).
            NOTE: its ``reduction`` attribute is overwritten in place.
        gamma: Focusing exponent applied to ``1 - p_t``.
        alpha: Weight for positive targets; negative targets get ``1 - alpha``.
        verbose: If True, print the shape of every intermediate tensor in
            ``forward`` (debugging aid). Defaults to False.
    """

    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25, verbose=False):
        super().__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma
        self.alpha = alpha
        self.verbose = verbose
        # Remember the reduction the caller asked for, then force the wrapped
        # loss to return per-element values so the focal weights can be applied
        # element-wise before reducing.
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'

    def _log_shape(self, name, tensor):
        """Print ``name`` followed by ``tensor.shape`` when verbose mode is on."""
        if self.verbose:
            print(name, tensor.shape)

    def forward(self, pred, true):
        """Compute the focal loss.

        Args:
            pred: Raw logits; the sigmoid is applied here.
            true: Targets with the same shape as ``pred`` (required by
                ``BCEWithLogitsLoss``); presumably in [0, 1], either hard 0/1
                labels or soft labels.

        Returns:
            ``loss.mean()`` if the wrapped criterion's reduction was 'mean',
            ``loss.sum()`` if it was 'sum', otherwise the per-element focal
            loss tensor.
        """
        self._log_shape("pred:", pred)
        self._log_shape("true:", true)
        loss = self.loss_fcn(pred, true)  # per-element BCE (reduction='none')
        self._log_shape("loss:", loss)

        pred_prob = torch.sigmoid(pred)  # probabilities from logits
        self._log_shape("pred_prob:", pred_prob)

        # p_t: probability the model assigns to the target outcome of each element.
        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
        self._log_shape("p_t:", p_t)

        # alpha_t: alpha for positive targets, 1 - alpha for negative targets.
        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
        self._log_shape("alpha_factor:", alpha_factor)

        # Down-weights well-classified elements (p_t close to 1).
        modulating_factor = (1.0 - p_t) ** self.gamma
        self._log_shape("modulating_factor:", modulating_factor)

        # Expanded form:
        #   loss = true * alpha * (1 - p_t)**gamma * loss
        #        + (1 - true) * (1 - alpha) * (1 - p_t)**gamma * loss
        # Out-of-place multiply leaves the wrapped loss's output tensor intact.
        loss = loss * alpha_factor * modulating_factor
        self._log_shape("focal_loss:", loss)

        # Honour the reduction the wrapped criterion was originally built with.
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
if __name__ == "__main__":
    # Demo: 852 rows x 2 columns of random logits and random soft targets,
    # scored with the focal loss wrapped around a default BCE-with-logits.
    criterion = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5, alpha=0.25)
    logits = torch.tensor(np.random.random((852, 2)))
    targets = torch.tensor(np.random.random((852, 2)))
    result = criterion(logits, targets)
    print(result)
1、在 yolov5 中，如果启用了 Focal Loss，则分类损失和置信度损失默认都会使用 FocalLoss。假设共预测出 852 个目标框，则输出如下（最后一行的损失值由随机输入计算得到，每次运行结果会有所不同）：
pred: torch.Size([852, 2])
true: torch.Size([852, 2])
loss: torch.Size([852, 2])
pred_prob: torch.Size([852, 2])
p_t: torch.Size([852, 2])
alpha_factor: torch.Size([852, 2])
modulating_factor: torch.Size([852, 2])
focal_loss: torch.Size([852, 2])
tensor(0.1520, dtype=torch.float64)
2、当把 yolov5 修改为旋转目标检测时，会增加角度信息，因而多出一个角度损失。若将角度范围 [0, 180) 按 1° 划分为 180 个类别，则对应的输出为：
pred: torch.Size([852, 180])
true: torch.Size([852, 180])
loss: torch.Size([852, 180])
pred_prob: torch.Size([852, 180])
p_t: torch.Size([852, 180])
alpha_factor: torch.Size([852, 180])
modulating_factor: torch.Size([852, 180])
focal_loss: torch.Size([852, 180])
tensor(0.1536, dtype=torch.float64)
3、代码中的 loss 计算公式可以展开如下：
`loss *= alpha_factor * modulating_factor`
等价于：
`loss = loss * [true * self.alpha + (1 - true) * (1 - self.alpha)] * [(1.0 - p_t) ** self.gamma]`
进一步展开为：
`loss = true * self.alpha * [(1.0 - p_t) ** self.gamma] * loss + (1 - true) * (1 - self.alpha) * [(1.0 - p_t) ** self.gamma] * loss`
即 $FL = \alpha_t (1 - p_t)^{\gamma} \cdot \mathrm{BCE}$，其中 $\alpha_t$ 对应 `alpha_factor`，$(1 - p_t)^{\gamma}$ 对应 `modulating_factor`。
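4、以下是 Focal Loss 的数学公式，可与代码对照阅读，下面两种写法是等价的。记 $p = \sigma(\text{pred})$，$y$ 为真实标签：
$$FL = \begin{cases} -\alpha\,(1-p)^{\gamma}\log(p), & y = 1 \\ -(1-\alpha)\,p^{\gamma}\log(1-p), & y = 0 \end{cases}$$
或统一写成：
$$FL(p_t) = -\alpha_t\,(1-p_t)^{\gamma}\log(p_t),\qquad p_t = y\,p + (1-y)(1-p),\qquad \alpha_t = y\,\alpha + (1-y)(1-\alpha)$$
5、另外值得注意的是，大多数讲解都会对比交叉熵与 Focal Loss 的公式。在实际计算 loss 时，这些公式都会乘以真实标签 $y$，用它在正样本项与负样本项之间进行选择。因此在上述代码中也会看到乘了一个 `true`，即真实标签。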