Focal and Global Knowledge Distillation: Knowledge Distillation for Object Detection Networks

Paper: https://arxiv.org/abs/2111.11837

GitHub: https://github.com/yzd-v/FGD

Method


FGD (Focal and Global Knowledge Distillation) combines focal distillation with global distillation, so that instance-level information, spatial/channel attention, and global correlation information are all taken into account.

First, the foreground-background separation masks and the attention maps are defined, and the focal distillation loss (composed of a feature loss and an attention loss) is then computed on the feature maps, as follows (a LaTeX sketch of these formulas is collected after the list):

  • A binary mask is introduced to separate foreground from background, where r denotes the ground-truth box region.

  • A scale mask is used to balance the foreground and background losses, where Hr and Wr denote the height and width of the ground-truth box.

  • Spatial attention and channel attention are obtained by reduced-mean operations over the feature map, and the attention masks are then obtained with a temperature-scaled softmax.

  • The feature loss is then computed from the teacher's and student's feature outputs (typically the neck features).

  • The attention loss is then computed, and together with the feature loss it forms the focal distillation loss, where l denotes the L1 loss.
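
A LaTeX sketch of these definitions, following the FGD paper's notation, is given below; this is a reconstruction for reference, so exact indices and constants should be checked against the paper (M: binary mask, S: scale mask, G/A: attention, T: temperature, f: adaptation layer, α/β/γ: loss weights):

M_{i,j} =
\begin{cases}
1, & (i,j) \in r \\
0, & \text{otherwise}
\end{cases}
\qquad
S_{i,j} =
\begin{cases}
\dfrac{1}{H_r W_r}, & (i,j) \in r \\[4pt]
\dfrac{1}{N_{bg}},  & \text{otherwise}
\end{cases},
\quad
N_{bg} = \sum_{i,j} \bigl(1 - M_{i,j}\bigr)

G^{S}(F)_{i,j} = \frac{1}{C} \sum_{c=1}^{C} \lvert F_{c,i,j} \rvert,
\qquad
G^{C}(F)_{c} = \frac{1}{HW} \sum_{i=1}^{H} \sum_{j=1}^{W} \lvert F_{c,i,j} \rvert

A^{S}(F) = H W \cdot \operatorname{softmax}\bigl(G^{S}(F)/T\bigr),
\qquad
A^{C}(F) = C \cdot \operatorname{softmax}\bigl(G^{C}(F)/T\bigr)

L_{fea} = \alpha \sum_{c,i,j} M_{i,j}\, S_{i,j}\, A^{S}_{i,j}\, A^{C}_{c}
          \bigl(F^{T}_{c,i,j} - f(F^{S})_{c,i,j}\bigr)^{2}
        + \beta \sum_{c,i,j} (1 - M_{i,j})\, S_{i,j}\, A^{S}_{i,j}\, A^{C}_{c}
          \bigl(F^{T}_{c,i,j} - f(F^{S})_{c,i,j}\bigr)^{2}

L_{at} = \gamma \bigl( l(A^{S}_{T}, A^{S}_{S}) + l(A^{C}_{T}, A^{C}_{S}) \bigr),
\qquad
L_{focal} = L_{fea} + L_{at}

Here the attention maps A^{S}, A^{C} used in L_{fea} are the teacher's, and f is the adaptation layer that matches the student's feature shape to the teacher's.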

def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):
    # Method of the FGD distillation loss module; assumes `import torch` and `import torch.nn as nn`.
    # Feature loss: masked MSE between student and teacher feature maps, weighted by the
    # teacher's spatial (S_t) and channel (C_t) attention.
    loss_mse = nn.MSELoss(reduction='sum')

    # Broadcast masks and attention maps so they can multiply a (N, C, H, W) feature map
    Mask_fg = Mask_fg.unsqueeze(dim=1)      # foreground mask -> (N, 1, H, W)
    Mask_bg = Mask_bg.unsqueeze(dim=1)      # background mask -> (N, 1, H, W)
    C_t = C_t.unsqueeze(dim=-1)
    C_t = C_t.unsqueeze(dim=-1)             # channel attention -> (N, C, 1, 1)
    S_t = S_t.unsqueeze(dim=1)              # spatial attention -> (N, 1, H, W)

    # Weight the teacher features by the square roots of the attention maps and masks,
    # so the squared error carries the full weights S_t * C_t * Mask
    fea_t = torch.mul(preds_T, torch.sqrt(S_t))
    fea_t = torch.mul(fea_t, torch.sqrt(C_t))
    fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))
    bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))

    # The student features are weighted by the teacher's attention as well
    fea_s = torch.mul(preds_S, torch.sqrt(S_t))
    fea_s = torch.mul(fea_s, torch.sqrt(C_t))
    fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))
    bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))

    # Foreground / background feature losses, averaged over the batch
    fg_loss = loss_mse(fg_fea_s, fg_fea_t) / len(Mask_fg)
    bg_loss = loss_mse(bg_fea_s, bg_fea_t) / len(Mask_bg)
    return fg_loss, bg_loss
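
As a complement to the snippet above, here is a minimal, self-contained sketch of how the spatial and channel attention maps fed into get_fea_loss (S_s/S_t and C_s/C_t) can be computed with the reduced-mean plus temperature-softmax scheme described earlier; the function name and signature here are illustrative rather than the repository's exact API:

import torch
import torch.nn.functional as F

def get_attention(preds, temp):
    # preds: feature map of shape (N, C, H, W); temp: distillation temperature
    N, C, H, W = preds.shape
    value = torch.abs(preds)

    # Spatial attention: mean of |F| over channels, temperature-scaled softmax over H*W
    fea_map = value.mean(dim=1)                       # (N, H, W)
    S_attention = (H * W * F.softmax(fea_map.view(N, -1) / temp, dim=1)).view(N, H, W)

    # Channel attention: mean of |F| over spatial positions, softmax over channels
    channel_map = value.mean(dim=[2, 3])              # (N, C)
    C_attention = C * F.softmax(channel_map / temp, dim=1)

    return S_attention, C_attention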

Global distillation extracts the correlations between different pixels so as to transfer context information, where R(F) denotes the GcBlock feature transformation; a sketch of the loss is given below.

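A LaTeX sketch of the global distillation loss, reconstructed for reference (λ is the loss weight; W_k, W_{v1}, W_{v2} are the GcBlock's 1x1 convolutions, LN is LayerNorm, and N_p = H·W is the number of pixels):

L_{global} = \lambda \sum \bigl( R(F^{T}) - R(F^{S}) \bigr)^{2}

R(F) = F + W_{v2}\,\mathrm{ReLU}\!\left( \mathrm{LN}\!\left( W_{v1}
       \sum_{j=1}^{N_p} \frac{e^{W_k F_j}}{\sum_{m=1}^{N_p} e^{W_k F_m}} \, F_j \right) \right)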

def get_rela_loss(self, preds_S, preds_T):
    # Global (relation) loss: align GcBlock-transformed student and teacher features
    loss_mse = nn.MSELoss(reduction='sum')

    # GcBlock context modelling: attention-pooled context vector for each input
    context_s = self.spatial_pool(preds_S, 0)
    context_t = self.spatial_pool(preds_T, 1)
    out_s = preds_S
    out_t = preds_T

    # Add the transformed context back onto the features (R(F) in the paper)
    channel_add_s = self.channel_add_conv_s(context_s)
    out_s = out_s + channel_add_s
    channel_add_t = self.channel_add_conv_t(context_t)
    out_t = out_t + channel_add_t

    # MSE between the transformed features, averaged over the batch
    rela_loss = loss_mse(out_s, out_t) / len(out_s)
    return rela_loss
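
For completeness, a minimal sketch of how the four terms could be combined into the final distillation loss that is added on top of the detector's original training loss; the helper name get_mask_loss and the weight names alpha/beta/gamma/lambd are illustrative here, not necessarily those used in the official repo:

import torch

def get_mask_loss(C_s, C_t, S_s, S_t):
    # Attention loss: L1 distance between student and teacher attention maps
    return torch.sum(torch.abs(C_s - C_t)) / len(C_s) + \
           torch.sum(torch.abs(S_s - S_t)) / len(S_s)

def fgd_total_loss(fg_loss, bg_loss, mask_loss, rela_loss, alpha, beta, gamma, lambd):
    # Focal distillation = weighted foreground/background feature loss + attention loss;
    # global distillation = weighted relation (GcBlock) loss.
    return alpha * fg_loss + beta * bg_loss + gamma * mask_loss + lambd * rela_loss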

Experimental Results

The paper applies FGD to various object detectors and analyzes the sensitivity to the individual loss terms, focal distillation, global distillation, and the temperature coefficient T (see the experiments section of the paper for details). Results on different detectors are reported in the paper's tables.

