[目标检测知识蒸馏1] [NIPS17] Learning Efficient Object Detection Models with Knowledge Distillation

文章目录

    • 什么是知识蒸馏?
    • 目标检测中的知识蒸馏
  • [NIPS17] **Learning Efficient Object Detection Models with Knowledge Distillation**
    • Introduction:
    • Method
      • 分类任务中的类别不均衡现象
      • 回归任务
      • Hint learning with Feature Adaption
    • Experiment

什么是知识蒸馏?

  • 知识蒸馏是指从大模型(Teacher model)中学习到有用的知识来训练小模型(Student model),在保证不损失太多性能的情况下,进行模型压缩。
  • 最早是为了解决模型压缩(轻量化)问题。
  • 在蒸馏过程中,student model 学习到 teacher model 的泛化能力,保留了接近 teacher model 的性能。 在保留精度的同时,能压缩模型,提升速度。但只在分类任务上得到了印证,在更复杂的 object detection 上还有待探索。

目标检测中的知识蒸馏

  • 目标检测任务 label 信息量更大,根据 label 学到的模型更为复杂,压缩后损失更多;
  • 分类任务中,每个类别相对均衡,同等重要。而目标检测任务中,存在类别不平衡问题,背景类偏多;
  • 目标检测任务更为复杂,既有类别分类,也有位置回归的预测;
  • 现行的知识蒸馏主要针对同一域中数据进行蒸馏,对于跨域目标检测的任务而言,对知识的蒸馏有更高的要求。

[NIPS17] Learning Efficient Object Detection Models with Knowledge Distillation

Introduction:

主要是通过设置三个 loss 函数,分别对 backbone、cls head、reg head 进行蒸馏:

  • 对于 backbone: 使用 hint learning进行蒸馏,增加一个 adaptation layers,让 feature map 的维度匹配;
  • 对于分类任务:使用 weighted CE Loss 解决类别失衡严重问题;
  • 对于回归任务:除了原本的 smooth- ℓ 1 \ell_1 1 loss,增加 teacher bounded regression loss。
    [目标检测知识蒸馏1] [NIPS17] Learning Efficient Object Detection Models with Knowledge Distillation_第1张图片

Method

教师网络的知识提取分为三点:**中间层 Feature Maps 的 Hint;RPN/RCN 中分类层的 knowledge;以及RPN/RCN 中回归层的 knowlege。**具体如下:
L R C N = 1 N ∑ i L c l s R C N + λ 1 N ∑ j L r e g R C N L R P N = 1 M ∑ i L c l s R P N + λ 1 N ∑ j L r e g R P N L = L R P N + L R C N + γ L H i n t L_{RCN}=\frac{1}{N}\sum_iL_{cls}^{RCN}+\lambda \frac{1}{N}\sum_jL_{reg}^{RCN}\\ L_{RPN}=\frac{1}{M}\sum_iL_{cls}^{RPN}+\lambda \frac{1}{N}\sum_jL_{reg}^{RPN}\\ L=L_{RPN}+L_{RCN}+\gamma L_{Hint} LRCN=N1iLclsRCN+λN1jLregRCNLRPN=M1iLclsRPN+λN1jLregRPNL=LRPN+LRCN+γLHint

  • N N N M M M 分别是对应部分的batch-size大小, λ \lambda λ γ \gamma γ 是超参数(这里分别设定为 1 1 1 0.5 0.5 0.5);
  • L c l s L_{cls} Lcls 包括 hard target 和知识蒸馏中的 soft target;
  • L r e g L_{reg} Lreg 包括 smooth- ℓ 1 \ell_1 1 和新提出的 teacher bounded ℓ 2 \ell_2 2 regression loss;
  • L H i n t L_{Hint} LHint 为主干网络的损失。

分类任务中的类别不均衡现象

教师网络和学生网络的输出分别如下:
P t = softmax ( Z t T ) P s = softmax ( Z s T ) P_t=\text{softmax}(\frac{Z_t}{T})\\ P_s=\text{softmax}(\frac{Z_s}{T}) Pt=softmax(TZt)Ps=softmax(TZs)
学生网络的优化损失如下:
L c l s = μ L h a r d ( P s ,   y ) + ( 1 − μ ) L s o f t ( P s ,   P t ) L_{cls}=\mu L_{hard}(P_s,~y)+(1-\mu)L_{soft}(P_s,~P_t) Lcls=μLhard(Ps, y)+(1μ)Lsoft(Ps, Pt)

  • L h a r d L_{hard} Lhard 是用 gt 监督的 Cross Entropy
  • L s o f t L_{soft} Lsoft 是用教师网络的信息监督的 soft loss。

分类任务中, 分类错误只会来自 foreground categories。目标检测中的分类子任务,background and foreground categories 都会导致错分。

  • 对于分类损失中的 background 误分概率占比较高的情况,提出增大蒸馏交叉熵中背景类的权重来解决失衡问题。
    L s o f t ( P s ,   P t ) = − ∑ w c P t log P s L_{soft}(P_s,~P_t)=-\sum w_cP_t\text{log}P_s Lsoft(Ps, Pt)=wcPtlogPs

回归任务

对于回归结果的蒸馏,**regression direction 可能和 gt 相差较大:**由于回归的输出是无界的,教师网络的预测方向可能与 gt 的方向相反。因此,将教师的输出损失作为上界,当学生网络的输出损失大于上界时,计入该损失;否则不考虑该 loss。
L b ( R S ,   R t ,   y ) = { ∥ R s − y ∥ 2 2 ,   if  ∥ R s − y ∥ 2 2 + m > ∥ R t − y ∥ 2 2 0 ,   otherwise L r e g = L s m o o t h − ℓ 1 ( R S ,   y r e g ) + ν L b ( R s ,   R t ,   y r e g ) L_b(R_S,~R_t,~y)= \begin{cases} \|R_s-y\|^2_2,~&\text{if}~\|R_s-y\|^2_2+m>\|R_t-y\|^2_2\\ 0,~&\text{otherwise} \end{cases} \\ L_{reg}=L_{smooth-\ell_1}(R_S,~y_{reg})+\nu L_b(R_s,~R_t,~y_{reg}) Lb(RS, Rt, y)={Rsy22, 0, if Rsy22+m>Rty22otherwiseLreg=Lsmooth1(RS, yreg)+νLb(Rs, Rt, yreg)

  • m is a margin,权重 ν = 0.5 \nu=0.5 ν=0.5
  • y r e g y_{reg} yreg denotes the regression ground truth label,是 proposal 和 gt 之间的回归量;
  • R t R_t Rt R s R_s Rs 分别是 teacher 和 student 网络学出来的回归量;
  • L s m o o t h − ℓ 1 L_{smooth-\ell_1} Lsmooth1是普通的 smooth ℓ 1 \ell_1 1 回归 loss。

Hint learning with Feature Adaption

论文中证明,using the intermediate representation of the teacher as hint can help the training process and improve the final performance of the student.
L = L R P N + L R C N + γ L H i n t L=L_{RPN}+L_{RCN}+\gamma L_{Hint} L=LRPN+LRCN+γLHint
其中 L H i n t L_{Hint} LHint是学生网络 backbone 的loss:
L H i n t ( V ,   Z ) = ∥ V − Z ∥ 2 2 L H i n t ( V ,   Z ) = ∥ V − Z ∥ 1 2 L_{Hint}(V,~Z)=\|V-Z\|^2_2\\ L_{Hint}(V,~Z)=\|V-Z\|^2_1 LHint(V, Z)=VZ22LHint(V, Z)=VZ12
变量 V ,   Z V,~Z V, Z 分别是教师网络和学生网络的 feature map(全 feature imitation),需要加入 adaption layer 使得二者维度相同。

Experiment

[目标检测知识蒸馏1] [NIPS17] Learning Efficient Object Detection Models with Knowledge Distillation_第2张图片

你可能感兴趣的:(知识蒸馏,目标检测,深度学习,计算机视觉)