Focal loss 知识蒸馏 目标检测 ResNet 特征金字塔

SOTA:state of the art

指在特定任务中目前表现最好的方法或模型

有了模型之后,我们需要通过定义损失函数来判断模型在样本上的表现

交叉熵loss

Focal loss 知识蒸馏 目标检测 ResNet 特征金字塔_第1张图片

Focal loss

Focal loss 知识蒸馏 目标检测 ResNet 特征金字塔_第2张图片

 

import torch
from torch.nn import functional as F

def sigmoid_focal_loss(
    inputs:torch.Tensor,
    targets:torch.Tensor,
    alpha:float=-1,
    gamma:float=2,#调节因子的指数
    reduction:str='none',#默认为none,传入mean的话outputs会被求平均,传入sum的话会被求和
)->torch.Tensor:
    """
    用于RetinaNet的loss函数
    """
    inputs=inputs.float(),
    targets=targets.float(),
    p=torch.sigmoid(inputs),
    ce_loss=F.binary_cross_entropy_with_logits(inputs,targets,reduction="none")
    p_t=p*targets+(1-p)*(1-targets)
    loss=ce_loss*((1-p_t)**gamma)
    if alpha>=0:
        alpha_t=alpha*targets+(1-alpha)*(1-targets)
        loss=alpha_t*loss
    if reduction=="mean":
        loss=loss.mean()
    elif reduction=="sum":
        loss=loss.sum()
return loss

    
    

RPN (Region Proposal Network) 用于生成候选区域(Region Proposal)

BasicBlock

shortcut 路径大致可以分成 2 种,取决于残差路径是否改变了feature map数量和尺寸。

Focal loss 知识蒸馏 目标检测 ResNet 特征金字塔_第3张图片

Focal loss 知识蒸馏 目标检测 ResNet 特征金字塔_第4张图片

特征金字塔

Focal loss 知识蒸馏 目标检测 ResNet 特征金字塔_第5张图片

你可能感兴趣的:(目标检测,深度学习,pytorch,机器学习)