yolov5 focal_loss source code walkthrough

I. Source code of the YOLOv5 focal loss:

import torch
import torch.nn as nn
import numpy as np

class FocalLoss(nn.Module):
    # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super().__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'  # required to apply FL to each element

    def forward(self, pred, true):
        print("pred:",pred.shape)
        print("true:", true.shape)
        loss = self.loss_fcn(pred, true)
        print("loss:",loss.shape)
        pred_prob = torch.sigmoid(pred)  # prob from logits
        print("pred_prob:", pred_prob.shape)
        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)  # probability of the true class (p_t)
        print("p_t:",p_t.shape)
        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
        print("alpha_factor:", alpha_factor.shape)
        modulating_factor = (1.0 - p_t) ** self.gamma
        print("modulating_factor:", modulating_factor.shape)

        """
        loss *= alpha_factor * modulating_factor
        等同于:
        loss = loss * [true * self.alpha + (1 - true) * (1 - self.alpha)] * [(1.0 - p_t) ** self.gamma]
        loss = true * self.alpha * [(1.0 - p_t) ** self.gamma] * loss  +   (1 - true) * (1 - self.alpha) *  [(1.0 - p_t) ** self.gamma] * loss 
        """

        loss *= alpha_factor * modulating_factor
        print("focal_loss:", loss.shape)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'none'
            return loss

if __name__ == "__main__":
    loss_fcn = nn.BCEWithLogitsLoss()
    focal_loss = FocalLoss(loss_fcn, gamma=1.5, alpha=0.25)
    pred = np.random.random((852, 2))  # random "logits" for a quick shape check
    true = np.random.random((852, 2))  # random soft targets (real labels would be 0/1)
    pred = torch.tensor(pred)          # float64 tensors, hence dtype=torch.float64 below
    true = torch.tensor(true)
    loss = focal_loss(pred,true)
    print(loss)

1. If focal loss is enabled in YOLOv5, both the classification loss and the objectness (confidence) loss use FocalLoss by default. Assuming 852 predicted boxes, the output is shown below, followed by a simplified excerpt of how this is wired up in utils/loss.py:

pred: torch.Size([852, 2])
true: torch.Size([852, 2])
loss: torch.Size([852, 2])
pred_prob: torch.Size([852, 2])
p_t: torch.Size([852, 2])
alpha_factor: torch.Size([852, 2])
modulating_factor: torch.Size([852, 2])
focal_loss: torch.Size([852, 2])
tensor(0.1520, dtype=torch.float64)
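The criterion setup in yolov5's utils/loss.py (ComputeLoss) looks roughly like the excerpt below; here h is the hyperparameter dict, and focal loss is only applied when fl_gamma > 0:

BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']]))  # classification criterion
BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']]))  # objectness criterion

g = h['fl_gamma']  # focal loss gamma; 0.0 leaves plain BCE in place
if g > 0:
    BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)  # wrap both criteria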

2. When YOLOv5 is modified for rotated (oriented) object detection, angle information is added and an extra angle loss appears; the angle range [0, 180) is discretized into 180 classes. The corresponding output is shown below, followed by a minimal sketch of such an angle branch:

pred: torch.Size([852, 180])
true: torch.Size([852, 180])
loss: torch.Size([852, 180])
pred_prob: torch.Size([852, 180])
p_t: torch.Size([852, 180])
alpha_factor: torch.Size([852, 180])
modulating_factor: torch.Size([852, 180])
focal_loss: torch.Size([852, 180])
tensor(0.1536, dtype=torch.float64)
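A minimal sketch of such an angle branch, assuming the angle is encoded as a 180-way one-hot classification target (variable names are illustrative; rotated-YOLOv5 variants often use smoothed CSL labels instead of plain one-hot):

n_boxes = 852
angles = torch.randint(0, 180, (n_boxes,))        # ground-truth angle (degrees) per box
t_theta = torch.zeros(n_boxes, 180)
t_theta[torch.arange(n_boxes), angles] = 1.0      # one-hot angle labels
p_theta = torch.randn(n_boxes, 180)               # predicted angle logits

focal_theta = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5, alpha=0.25)
loss_theta = focal_theta(p_theta, t_theta)        # prints [852, 180] shapes, returns a scalar
print(loss_theta)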

3. The loss computation in the code can be expanded as follows:

loss *= alpha_factor * modulating_factor
is equivalent to:
loss = loss * [true * self.alpha + (1 - true) * (1 - self.alpha)] * [(1.0 - p_t) ** self.gamma]
which is equivalent to:
loss = true * self.alpha * [(1.0 - p_t) ** self.gamma] * loss + (1 - true) * (1 - self.alpha) * [(1.0 - p_t) ** self.gamma] * loss
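A quick numerical check of this equivalence (illustrative only; small random tensors, names are made up):

torch.manual_seed(0)
logits = torch.randn(4, 2)
labels = torch.randint(0, 2, (4, 2)).float()
bce = nn.BCEWithLogitsLoss(reduction='none')(logits, labels)  # element-wise BCE
p = torch.sigmoid(logits)
p_t = labels * p + (1 - labels) * (1 - p)
alpha, gamma = 0.25, 1.5
fused = bce * (labels * alpha + (1 - labels) * (1 - alpha)) * (1.0 - p_t) ** gamma
expanded = (labels * alpha * (1.0 - p_t) ** gamma * bce
            + (1 - labels) * (1 - alpha) * (1.0 - p_t) ** gamma * bce)
print(torch.allclose(fused, expanded))  # True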


II. The FocalLoss formula:

1. Below is the mathematical formula of focal loss, which can be read side by side with the code; both of the following forms are correct:
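For reference, these are the two standard ways of writing focal loss (Lin et al., 2017), with p the predicted probability and y ∈ {0, 1} the label:

\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \log(p_t),
\qquad
p_t = \begin{cases} p, & y = 1 \\ 1 - p, & y = 0 \end{cases},
\qquad
\alpha_t = \begin{cases} \alpha, & y = 1 \\ 1 - \alpha, & y = 0 \end{cases}

and, equivalently, the expanded per-element form that the code reproduces:

\mathrm{FL}(p, y) = -\,\alpha\, y\, (1 - p)^{\gamma} \log(p) \;-\; (1 - \alpha)(1 - y)\, p^{\gamma} \log(1 - p)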

2. It is also worth noting that most tutorials contrast the cross-entropy formula with the focal loss formula; when the loss is actually computed, each term is multiplied by the ground-truth label, which is why the code above also multiplies by `true`, i.e. the ground-truth label.
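To make that explicit: with reduction='none', nn.BCEWithLogitsLoss already contains the label, and multiplying it by alpha_factor and modulating_factor reproduces the expanded formula above (the cross terms vanish because y(1 - y) = 0 for hard labels):

\mathrm{BCE}(p, y) = -\big[\, y \log p + (1 - y) \log(1 - p) \,\big]

\alpha_t \,(1 - p_t)^{\gamma}\, \mathrm{BCE}(p, y)
= -\,\alpha\, y\, (1 - p)^{\gamma} \log p \;-\; (1 - \alpha)(1 - y)\, p^{\gamma} \log(1 - p)
= \mathrm{FL}(p, y)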

 
