Knowledge Distillation(8)——Learning Efficient Object Detection Models with Knowledge Distillation

Learning Efficient Object Detection Models with Knowledge Distillation

  • 概述
  • Method
    • Knowledge Distillation for Classification with Imbalanced Classes
    • Knowledge Distillation for Regression with Teacher Bounds
    • Hint Learning with Feature Adaptation
  • 实验结果

之前博客整理的论文都是knowledge distillation及其变体,作为机器学习的一种方法的研究发展历程。从这篇博客开始,我将介绍其在CV领域的一些具体的用法。

本文是knowledge distillation在detection上成功应用的一个例子。

概述

knowledge distillation和hint learning在classification已经很成功了。然而对于detection,soft target不再是单一的类别概率输出,regression、proposal、less voluminous labels(较少的标签)都是在detection种使用distillation的挑战:
Knowledge Distillation(8)——Learning Efficient Object Detection Models with Knowledge Distillation_第1张图片
本文应该是最先在detection种成功使用distillation的,主要idea有:

  1. end-to-end 使用distillation方式训练
  2. loss定义,a) weighted cross entropy loss for classification b) teacher bounded regression loss for knowledge distillation c) adaptation layers for hint learning
    Knowledge Distillation(8)——Learning Efficient Object Detection Models with Knowledge Distillation_第2张图片

Method

整体框架:
Knowledge Distillation(8)——Learning Efficient Object Detection Models with Knowledge Distillation_第3张图片
总的loss有三大部分。 L H i n t L_{Hint} LHint是使用hint-based方法学习teacher经过backbone后feature的表达; L R P N L_{RPN} LRPN是学习teacher RPN部分的proposals,包括classification®ression两部分loss; L R C N L_{RCN} LRCN是学习teacher fast rcnn detector部分的prediction,也包括classification score®ression factor两部分
Knowledge Distillation(8)——Learning Efficient Object Detection Models with Knowledge Distillation_第4张图片
loss按种类来分,有三种计算方式,分别用于classification、regression、feature adaption
下面依此介绍:

Knowledge Distillation for Classification with Imbalanced Classes

detection与classification的差异,detection在fg和bg的区分上容易误判,所以引入一个权重加大对bg类的惩罚 w 0 = 1.5 w_0=1.5 w0=1.5,其他 w i w_i wi都是1:
Knowledge Distillation(8)——Learning Efficient Object Detection Models with Knowledge Distillation_第5张图片
另外,对于简单的数据集且是分类任务,需要设置temperature T,是输出分布更soft,拉近类别间的差距。
而对于detection这样一个本身就比较难的任务来说,很多类都有明显的预测误差,设置T=1时性能是最好的
Knowledge Distillation(8)——Learning Efficient Object Detection Models with Knowledge Distillation_第6张图片

Knowledge Distillation for Regression with Teacher Bounds

Knowledge Distillation for Regression有一个麻烦就是 regression direction可能和GT相差甚远:
在这里插入图片描述
策略是,当student的 R s R_s Rs偏的比teacher还要离谱一些(margin m)时,加入一个loss惩罚
和teacher的表现相近时,就不push了,此项loss为0
Knowledge Distillation(8)——Learning Efficient Object Detection Models with Knowledge Distillation_第7张图片
可以看到即使偏的离谱的时候,student学习的还是hard label。可能这个没有软分布,teacher输出可能误差也比较大,不如直接学习GT,只是和teacher表现差距太大时,多加一项loss

Hint Learning with Feature Adaptation

作者似乎L1 L2 loss 都进行了尝试:
Knowledge Distillation(8)——Learning Efficient Object Detection Models with Knowledge Distillation_第8张图片
当hint layer 和 guided layer不匹配时,需要用一个adaption layer进行转换。hint and guided都是FC时就用FC,都是conv layer就用1x1conv来匹配。而且作者发现即便channels数一样时,一个额外的adaptation layer也有助于实现高效的知识迁移
Knowledge Distillation(8)——Learning Efficient Object Detection Models with Knowledge Distillation_第9张图片

实验结果

可以看出来提升还是挺明显的,用VGGM做detection,效果对比未使用蒸馏提升了四个点。这个网络速度只占VGG16的1/3,虽然效果还是查了蛮多。
Knowledge Distillation(8)——Learning Efficient Object Detection Models with Knowledge Distillation_第10张图片
Knowledge Distillation(8)——Learning Efficient Object Detection Models with Knowledge Distillation_第11张图片

你可能感兴趣的:(Knowledge,Distillation,知识蒸馏)