小白科研笔记:深入理解mmDetection框架——损失函数

1. 前言

这篇博客主要分析mmDetection框架中常见的损失函数,以及它们的具体实现。

2. 目标检测中常见的损失函数

2.1 平滑L1范数

平滑L1范数用于描述目标框回归中的损失函数:

Smooth _ L 1 ( x , x g t ) = { 0.5 ∥ x − x g t ∥ 2 2 / β , ∥ x − x g t ∥ 1 < β ∥ x − x g t ∥ 1 − 0.5 β , ∥ x − x g t ∥ 1 ≥ β \text{Smooth} \_L1(x,x_{gt}) = \left\{ \begin{aligned} & 0.5\Vert x-x_{gt} \Vert_2^2/\beta,& \Vert x-x_{gt} \Vert_1 < \beta\\ & \Vert x-x_{gt} \Vert_1 - 0.5\beta, & \Vert x-x_{gt} \Vert_1 \geq \beta\\ \end{aligned} \right. Smooth_L1(x,xgt)={0.5xxgt22/β,xxgt10.5β,xxgt1<βxxgt1β

平滑L1范数的优势可以参考这篇简书笔记。如果预测变量 x x x的误差很大,为了避免给网络反向传递一个非常大的梯度导致训练震荡,就使用L1范数;如果预测变量 x x x的误差很小,需要给网络反向传递一个比较小的梯度,就使用L2范数。这是平滑L1范数的优势所在。它在Pytorch下实现的代码如下所示:

def smooth_l1_loss(pred, target, beta=1.0, reduction='mean'):
    assert beta > 0
    assert pred.size() == target.size() and target.numel() > 0
    diff = torch.abs(pred - target)
    # 使用 torch.where 用于区分两种情况
    loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
                       diff - 0.5 * beta)
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, mean:1, sum: 2
    if reduction_enum == 0:
        return loss
    # loss 是个向量,这里对 loss 内各个元素做平均,得到标量 loss
    elif reduction_enum == 1:
        return loss.sum() / pred.numel() 
    elif reduction_enum == 2:
        return loss.sum()

平滑L1范数还能推广为加权平滑L1范数(每个样本点的权重不一样):

Loss = ∑ i = 1 N w i ∗ Smooth _ L 1 ( x i , x i , g t ) ;    ∑ i = 1 N w i = 1 \text{Loss}=\sum_{i=1}^N w_i *\text{Smooth}\_L1(x_i,x_{i,gt}); \,\, \sum_{i=1}^N w_i = 1 Loss=i=1NwiSmooth_L1(xi,xi,gt);i=1Nwi=1

它在Pytorch下实现的代码如下所示:

def weighted_smoothl1(pred, target, weight, beta=1.0, avg_factor=None):
    if avg_factor is None:
        avg_factor = torch.sum(weight > 0).float().item() + 1e-6
    loss = smooth_l1_loss(pred, target, beta, reduction='none')
    return torch.sum(loss * weight)[None] / avg_factor

2.2 交叉熵损失函数

交叉熵损失函数用于表示目标检测中目标分类问题。单分类情形下交叉熵损失函数表示为:

Cross_Entropy ( p i , y i ) = { − log ⁡ ( p i ) , y i = 1 − log ⁡ ( 1 − p i ) , y i = 0 \text{Cross\_Entropy}(p_i,y_i)=\left\{ \begin{aligned} &-\log(p_i),&y_i=1 \\ &-\log(1-p_i), &y_i=0\\ \end{aligned} \right. Cross_Entropy(pi,yi)={log(pi),log(1pi),yi=1yi=0

其中 p i p_i pi表示预测二分类概率,而 y i y_i yi表示真值。单分类情形下,它们都是标量。

多分类情形下(类别个数为 C C C)交叉熵损失函数表示为:

Multi_Cross_Entropy ( p i , y i ) = − y i ⋅ l o g ( p i ) \text{Multi\_Cross\_Entropy}(p_i,y_i)=-y_i \cdot log(p_i) Multi_Cross_Entropy(pi,yi)=yilog(pi)

其中 y i y_i yi是一个one-hot类型 C × 1 C\times 1 C×1的向量,表示为 y i = ( 0 , . . . , 0 , 1 , 0 , . . . , 0 ) T y_i=(0,...,0,1,0,...,0)^T yi=(0,...,0,1,0,...,0)T。如果 y i y_i yi属于第 k k k类,那么向量 y i y_i yi中只有第 k k k个位置是 1 1 1 p i p_i pi是网络模型预测和Softmax作用的结果,它是一个 C × 1 C\times 1 C×1的向量,表示为 p i = ( p i , 0 , . . . , p i , C ) T p_i=(p_{i,0},...,p_{i,C})^T pi=(pi,0,...,pi,C)T,指各个类别的概率。它满足归一化的条件: ∑ j = 1 C p i , j = 1 \sum_{j=1}^C p_{i,j} =1 j=1Cpi,j=1。上式中 “ ⋅ ” “\cdot” 表示向量之间各个元素的点乘。

此外,交叉熵损失函数可以推广为加权交叉熵损失函数(每个样本点的权重不一样):

Loss = ∑ i = 1 N w i ∗ Multi_Cross_Entropy ( p i , y i ) ;    ∑ i = 1 N w i = 1 \text{Loss}=\sum_{i=1}^N w_i*\text{Multi\_Cross\_Entropy}(p_i,y_i); \,\, \sum_{i=1}^N w_i = 1 Loss=i=1NwiMulti_Cross_Entropy(pi,yi);i=1Nwi=1

它在Pytorch下实现的代码如下所示:

def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True):
    if avg_factor is None:
        avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
    raw = F.cross_entropy(pred, label, reduction='none')
    if reduce:
        return torch.sum(raw * weight)[None] / avg_factor
    else:
        return raw * weight / avg_factor

上述代码中交叉熵计算的代码放在下面讨论:

def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100,
                  reduce=None, reduction='mean'):

    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)
    # nll_loss 相当于是交叉熵损失函数中的负号
    # 原始数据需要先做 softmax 归一化再去 log,合在一起就是 log_softmax
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)

简单地讲,nll_loss相当于是交叉熵损失函数中的负号,可以看看这篇知乎笔记。

2.3 Focal Loss

Focal loss是交叉熵损失函数的改进版,用于解决不同类别之间样本不均衡,以及难易样本差异大的情况。在我之前的博客讨论过。单分类情形下Focal Loss损失函数表示为:

Focal_loss ( p i , y i ) = { − α ( 1 − p i ) γ log ⁡ ( p i ) , y i = 1 − ( 1 − α ) p i γ log ⁡ ( 1 − p i ) , y i = 0 \text{Focal\_loss}(p_i,y_i)=\left\{ \begin{aligned} &-\alpha(1-p_i)^\gamma\log(p_i),&y_i=1 \\ &-(1-\alpha)p_i^\gamma\log(1-p_i), &y_i=0\\ \end{aligned} \right. Focal_loss(pi,yi)={α(1pi)γlog(pi),(1α)piγlog(1pi),yi=1yi=0

其中 p i p_i pi表示预测二分类概率,而 y i y_i yi表示真值。单分类情形下,它们都是标量。 α \alpha α表示正负样本之间的数量比例。 γ \gamma γ表示难易参数,会设置 γ > 1 \gamma>1 γ>1表示重视困难样本的损失误差。

多分类情形下(类别个数为 C C CFocal Loss损失函数表示为:

Multi_Focal_Loss ( p i , y i ) = − ( 1 − α ) ⋅ ( 1 − y i ) γ ⋅ y i ⋅ l o g ( p i ) \text{Multi\_Focal\_Loss}(p_i,y_i)=-(\textbf{1}-\alpha) \cdot (\textbf{1}-y_i)^\gamma \cdot y_i \cdot log(p_i) Multi_Focal_Loss(pi,yi)=(1α)(1yi)γyilog(pi)

其中 y i y_i yi是一个one-hot类型 C × 1 C\times 1 C×1的向量,表示为 y i = ( 0 , . . . , 0 , 1 , 0 , . . . , 0 ) T y_i=(0,...,0,1,0,...,0)^T yi=(0,...,0,1,0,...,0)T。如果 y i y_i yi属于第 k k k类,那么向量 y i y_i yi中只有第 k k k个位置是 1 1 1 p i p_i pi是网络模型预测和Softmax作用的结果,它是一个 C × 1 C\times 1 C×1的向量,表示为 p i = ( p i , 0 , . . . , p i , C ) T p_i=(p_{i,0},...,p_{i,C})^T pi=(pi,0,...,pi,C)T,指各个类别的概率。它满足归一化的条件: ∑ j = 1 C p i , j = 1 \sum_{j=1}^C p_{i,j} =1 j=1Cpi,j=1。上式中 “ ⋅ ” “\cdot” 表示向量之间各个元素的点乘。 1 \textbf{1} 1表示 C × 1 C\times 1 C×1的向量,填充元素的值都是 1 1 1 α = ( α 1 , . . . , α C ) T \alpha=(\alpha_1,...,\alpha_C)^T α=(α1,...,αC)T C × 1 C\times 1 C×1的向量,表示各个类别的占比。

从编程角度讲,Focal Loss相比于交叉熵损失函数,仅仅多了 ( 1 − α ) ⋅ ( 1 − y i ) γ (\textbf{1}-\alpha) \cdot (\textbf{1}-y_i)^\gamma (1α)(1yi)γ项而已。对于二分类的Focal Loss,它在Pytorch下实现的代码如下所示:

def sigmoid_focal_loss(pred,
                       target,
                       weight, # 表示各个样本点的权重
                       gamma=2.0,
                       alpha=0.25,
                       reduction='mean'):
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # pt 和 weight 的计算方式很精妙,节省了 if 语句
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
    weight = weight * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * weight
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, mean:1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()

其中函数binary_cross_entropy专指二分类问题的交叉熵损失函数。

此外,Focal loss损失函数可以推广为加权Focal loss损失函数(每个样本点的权重不一样):

Loss = ∑ i = 1 N w i ∗ Multi_Focal_Loss ( p i , y i ) ;    ∑ i = 1 N w i = 1 \text{Loss}=\sum_{i=1}^N w_i*\text{Multi\_Focal\_Loss}(p_i,y_i); \,\, \sum_{i=1}^N w_i = 1 Loss=i=1NwiMulti_Focal_Loss(pi,yi);i=1Nwi=1

它在Pytorch下实现的代码如下所示:

def weighted_sigmoid_focal_loss(pred,
                                target,
                                weight,
                                gamma=2.0,
                                alpha=0.25,
                                avg_factor=None,
                                num_classes=80):
    if avg_factor is None:
        avg_factor = torch.sum(weight > 0).float().item() / num_classes + 1e-6
    return sigmoid_focal_loss(
        pred, target, weight, gamma=gamma, alpha=alpha,
        reduction='sum')[None] / avg_factor

3. 结束语

至此,讨论了目标检测中常见的损失函数的数学表达式和编程实现。最后用一张图概括这篇博客所讨论的内容。

小白科研笔记:深入理解mmDetection框架——损失函数_第1张图片
图1:目标检测识别中常用的损失函数

你可能感兴趣的:(computer,vision论文代码分析)