Similarity-Preserving Knowledge Distillation论文阅读

今天分享一篇2020年ICCV关于知识蒸馏的论文,论文地址点这里

一. 介绍

知识蒸馏是一种用于监督“学生”神经网络训练的通用技术,它通过捕获和转移训练过的“教师”网络的知识来实现。虽然最初的动机是为了资源高效深度学习的神经网络压缩任务,但知识蒸馏已经在特权学习、对抗性防御[25]和噪声数据学习[19]等领域找到了更广泛的应用。知识蒸馏在概念上很简单:它通过额外的蒸馏损失来指导学生网络的训练,从而鼓励学生模仿教师网络的某些方面。直观地看,经过训练的教师网络比单独的数据监督(如标注的课堂标签)提供了更丰富的监督信号。
在本文中,我们提出了一种新的知识蒸馏形式,其灵感来自于观察到的语义相似的输入往往会在训练过的神经网络中引发相似的激活模式。保持相似的知识蒸馏指导学生网络的训练,这样,在受过训练的教师网络中产生相似(不相似)激活的输入对在学生网络中产生相似(不相似)激活。下入显示了整个过程。给定一个b个图像的输入小批,我们从输出激活映射计算成对的相似矩阵。b × b矩阵编码了网络激活的相似性,就像小批中的图像所诱发的那样。我们的蒸馏损失是由学生和老师产生的成对相似矩阵定义的。
Similarity-Preserving Knowledge Distillation论文阅读_第1张图片

二. 方法

在传统的知识蒸馏中,知识是以软化的类的分数的形式进行编码和传递的。训练学生的全部损失由以下部分构成:
L = ( 1 − α ) L C E ( y , σ ( z S ) ) + 2 α T 2 L C E ( σ ( z S T ) , σ ( z T T ) ) \mathcal{L}=(1-\alpha) \mathcal{L}_{\mathrm{CE}}\left(\mathbf{y}, \sigma\left(\mathbf{z}_S\right)\right)+2 \alpha T^2 \mathcal{L}_{\mathrm{CE}}\left(\sigma\left(\frac{\mathbf{z}_S}{T}\right), \sigma\left(\frac{\mathbf{z}_T}{T}\right)\right) L=(1α)LCE(y,σ(zS))+2αT2LCE(σ(TzS),σ(TzT))
其中 L C E ( ⋅ , ⋅ ) \mathcal{L}_{CE}(\cdot,\cdot) LCE(,)表示为交叉熵损失函数, σ ( ⋅ ) \sigma(\cdot) σ()表示为softmax函数, z S \mathbf{z}_S zS z T \mathbf{z}_T zT表示模型的回归值, T T T为温度超参数。
回顾介绍,语义相似的输入在经过训练的神经网络中往往会引发相似的激活模式。激活中的相关性是否可以编码有用的教师知识,从而传递给学生?我们的假设是,如果两个输入在教师网络中产生高度相似的激活,那么引导学生网络走向一个也会导致两个输入在学生网络中产生高度相似的激活的配置是有益的。相反,如果两个输入在老师体内产生不同的激活,我们希望这些输入在学生体内也产生不同的激活。
给定一个mini-batch的输入,我们使用 A T ( l ) ∈ R b × c × h × w A_T^{(l)} \in \mathbf{R}^{b \times c \times h \times w} AT(l)Rb×c×h×w表示为老师网络 T T T在第 l l l层的激活输出,其中 b b b表示为batch size, c c c为输出通道, h h h w w w为空间的维度。同理,对于学生网络 S S S来说,我们也可以使用 A S ( l ′ ) ∈ R b × c ′ × h ′ × w ′ A_S^{\left(l^{\prime}\right)} \in \mathbf{R}^{b \times c^{\prime} \times h^{\prime} \times w^{\prime}} AS(l)Rb×c×h×w表示为第 l ′ l' l层的激活输出。这里 c c c不一定等于 c ′ c' c,同理对于 h , w h,w h,w也是。如果两个网络的深度相同,那么层 l l l l ′ l' l的深度也是一样的,如果深度不同,那么我们就使用同一个块的最后一层。为了引导学生像老师网络中的激活靠近,我们定义了一个蒸馏损失。首先,我们使用如下表示:
G ~ T ( l ) = Q T ( l ) ⋅ Q T ( l ) ⊤ ; G T [ i , i ] ( l ) = G ~ T [ i , i ] ( l ) / ∥ G ~ T [ i , i ] ( l ) ∥ 2 \tilde{G}_T^{(l)}=Q_T^{(l)} \cdot Q_T^{(l) \top} ; \quad G_{T[i, i]}^{(l)}=\tilde{G}_{T[i, i]}^{(l)} /\left\|\tilde{G}_{T[i, i]}^{(l)}\right\|_2 G~T(l)=QT(l)QT(l);GT[i,i](l)=G~T[i,i](l)/ G~T[i,i](l) 2
其中 Q T ( l ) ∈ R b × c h w Q_T^{(l)} \in \mathbf{R}^{b \times c h w} QT(l)Rb×chw是对 A T ( l ) A^{(l)}_T AT(l)的维度转换,因此 G ~ T ( l ) \tilde{G}_T^{(l)} G~T(l)是一个 b × b b \times b b×b的矩阵。可以发现,对于这个矩阵的第 i i i j j j列表示为第 i i i张图片和第 j j j张图片的相似度。之后使用L2正则化对每一行进行处理,同理,对于学生网络来说也是如此:
G ~ S ( l ) = Q S ( l ) ⋅ Q S ( l ) ⊤ ; G S [ i , i ] ( l ) = G ~ S [ i , i ] ( l ) / ∥ G ~ S [ i , i ] ( l ) ∥ 2 \tilde{G}_S^{(l)}=Q_S^{(l)} \cdot Q_S^{(l) \top} ; \quad G_{S[i, i]}^{(l)}=\tilde{G}_{S[i, i]}^{(l)} /\left\|\tilde{G}_{S[i, i]}^{(l)}\right\|_2 G~S(l)=QS(l)QS(l);GS[i,i](l)=G~S[i,i](l)/ G~S[i,i](l) 2
之后我们就可以定义similarity-preserving knowledge loss:
L S P ( G T , G S ) = 1 b 2 ∑ ( l , l ′ ) ∈ I ∥ G T ( l ) − G S ( l ′ ) ∥ F 2 \mathcal{L}_{\mathrm{SP}}\left(G_T, G_S\right)=\frac{1}{b^2} \sum_{\left(l, l^{\prime}\right) \in \mathcal{I}}\left\|G_T^{(l)}-G_S^{\left(l^{\prime}\right)}\right\|_F^2 LSP(GT,GS)=b21(l,l)I GT(l)GS(l) F2
其中 I \mathcal{I} I表示为两个网络中的每一对的层(就像上面讲的,深度一样的话就一一对应,不一样的话就按照块对应)。 ∣ ∣ ⋅ ∣ ∣ F ||\cdot||_F ∣∣F表示为Frobenius 标准化。最终,对于学生网络来说我们可以定义整个的损失:
L = L C E ( y , σ ( z S ) ) + γ L S P ( G T , G S ) \mathcal{L}=\mathcal{L}_{\mathrm{CE}}\left(\mathbf{y}, \sigma\left(\mathbf{z}_S\right)\right)+\gamma \mathcal{L}_{\mathrm{SP}}\left(G_T, G_S\right) L=LCE(y,σ(zS))+γLSP(GT,GS)

三. 一些总结

其实思路还是很清晰的,这次蒸馏是在每一层的激活后进行相似度的判断的。代码的话等我找到进行更新。

你可能感兴趣的:(每日一次AI论文阅读,论文阅读,知识蒸馏,ICCV,集成学习)