主要是通过设置三个 loss 函数,分别对 backbone、cls head、reg head 进行蒸馏:
教师网络的知识提取分为三点:**中间层 Feature Maps 的 Hint;RPN/RCN 中分类层的 knowledge;以及RPN/RCN 中回归层的 knowlege。**具体如下:
L R C N = 1 N ∑ i L c l s R C N + λ 1 N ∑ j L r e g R C N L R P N = 1 M ∑ i L c l s R P N + λ 1 N ∑ j L r e g R P N L = L R P N + L R C N + γ L H i n t L_{RCN}=\frac{1}{N}\sum_iL_{cls}^{RCN}+\lambda \frac{1}{N}\sum_jL_{reg}^{RCN}\\ L_{RPN}=\frac{1}{M}\sum_iL_{cls}^{RPN}+\lambda \frac{1}{N}\sum_jL_{reg}^{RPN}\\ L=L_{RPN}+L_{RCN}+\gamma L_{Hint} LRCN=N1i∑LclsRCN+λN1j∑LregRCNLRPN=M1i∑LclsRPN+λN1j∑LregRPNL=LRPN+LRCN+γLHint
教师网络和学生网络的输出分别如下:
P t = softmax ( Z t T ) P s = softmax ( Z s T ) P_t=\text{softmax}(\frac{Z_t}{T})\\ P_s=\text{softmax}(\frac{Z_s}{T}) Pt=softmax(TZt)Ps=softmax(TZs)
学生网络的优化损失如下:
L c l s = μ L h a r d ( P s , y ) + ( 1 − μ ) L s o f t ( P s , P t ) L_{cls}=\mu L_{hard}(P_s,~y)+(1-\mu)L_{soft}(P_s,~P_t) Lcls=μLhard(Ps, y)+(1−μ)Lsoft(Ps, Pt)
分类任务中, 分类错误只会来自 foreground categories。目标检测中的分类子任务,background and foreground categories 都会导致错分。
对于回归结果的蒸馏,**regression direction 可能和 gt 相差较大:**由于回归的输出是无界的,教师网络的预测方向可能与 gt 的方向相反。因此,将教师的输出损失作为上界,当学生网络的输出损失大于上界时,计入该损失;否则不考虑该 loss。
L b ( R S , R t , y ) = { ∥ R s − y ∥ 2 2 , if ∥ R s − y ∥ 2 2 + m > ∥ R t − y ∥ 2 2 0 , otherwise L r e g = L s m o o t h − ℓ 1 ( R S , y r e g ) + ν L b ( R s , R t , y r e g ) L_b(R_S,~R_t,~y)= \begin{cases} \|R_s-y\|^2_2,~&\text{if}~\|R_s-y\|^2_2+m>\|R_t-y\|^2_2\\ 0,~&\text{otherwise} \end{cases} \\ L_{reg}=L_{smooth-\ell_1}(R_S,~y_{reg})+\nu L_b(R_s,~R_t,~y_{reg}) Lb(RS, Rt, y)={∥Rs−y∥22, 0, if ∥Rs−y∥22+m>∥Rt−y∥22otherwiseLreg=Lsmooth−ℓ1(RS, yreg)+νLb(Rs, Rt, yreg)
论文中证明,using the intermediate representation of the teacher as hint can help the training process and improve the final performance of the student.
L = L R P N + L R C N + γ L H i n t L=L_{RPN}+L_{RCN}+\gamma L_{Hint} L=LRPN+LRCN+γLHint
其中 L H i n t L_{Hint} LHint是学生网络 backbone 的loss:
L H i n t ( V , Z ) = ∥ V − Z ∥ 2 2 L H i n t ( V , Z ) = ∥ V − Z ∥ 1 2 L_{Hint}(V,~Z)=\|V-Z\|^2_2\\ L_{Hint}(V,~Z)=\|V-Z\|^2_1 LHint(V, Z)=∥V−Z∥22LHint(V, Z)=∥V−Z∥12
变量 V , Z V,~Z V, Z 分别是教师网络和学生网络的 feature map(全 feature imitation),需要加入 adaption layer 使得二者维度相同。