代码地址:https://github.com/jingyang2017/KD_SRRL
本文通过知识提炼解决了模型压缩的问题。我们主张采用一种优化学生网络倒数第二层的输出特征的方法,因此与表征学习直接相关。为此,我们首先提出了一种直接的特征匹配方法,它只关注优化学生的倒数第二层。其次,更重要的是,由于特征匹配没有考虑到手头的分类问题,我们提出了第二种方法,将表征学习和分类解耦,利用教师的预训练分类器来训练学生的倒数第二层特征。特别是,对于相同的输入图像,我们希望教师和学生的特征在通过教师的分类器时产生相同的输出,这可以通过一个简单的L2损失来实现。我们的方法实现起来非常简单,训练起来也很直接,而且在大量的实验环境中,包括不同的(a)网络结构、(b)师生能力、(c)数据集和(d)领域,我们的方法都表现出了持续超越之前最先进的方法。
…
在最近的无监督学习和监督学习的工作中,训练一个输出特征表示丰富且强大的网络对于实现后续分类任务的高准确率至关重要,例如见(Chen等人,2020;He等人,2020)和(Kang等人,2020)。因此,在本文中,我们主张通过优化学生的倒数第二层输出特征来进行基于表征学习的知识提炼。如果我们能够有效地做到这一点,我们期望(并通过实验证明)最终得到的学生网络能够比(Hinton等人,2015)的KD论文中用logit匹配训练的网络有更好的泛化能力。
图1:我们的方法通过最小化教师和学生的倒数第二特征表示hT和hS之间的差异来进行知识提炼。为此,我们建议使用两种损失。(a)特征匹配损失LF M,和(b)所谓的Softmax回归损失LSR。与LF M相反,我们的主要贡献LS R,被设计为考虑到手头的分类任务。为此,LS R规定,对于相同的输入图像,教师和学生产生的特征在通过教师的预训练和冻结的分类器时产生相同的输出。注意,为了简单起见,没有显示使hT和hS的特征维度相同的函数。
主要贡献。为了实现上述目标,我们提出了两个损失函数。第一个损失函数类似于(Romero等人,2015;Zagoruyko和Komodakis,2017),基于直接特征匹配,但只专注于优化学生的倒数第二层特征。由于学生的表征能力较低,直接的特征匹配可能比较困难,更重要的是,它脱离了手头的分类任务,因此我们还提出了第二个损失函数:我们建议将表征学习和分类解耦,利用教师的预训练分类器来训练学生的倒数第二层特征。特别是,对于相同的输入图像,我们希望教师和学生的特征在通过教师的分类器时产生相同的输出,这可以通过一个简单的L2损失实现(见图1)。这个softmax回归投影用于保留学生特征中与分类相关的信息,但由于投影矩阵是预先训练好的(在教师训练阶段学习的),这并不影响学生特征的表现力。
主要结果。我们的方法有两个优点。(1) 它简单而直接地实现。(2)在大量的实验环境中,它的性能一直优于最先进的方法,包括不同的(a)网络结构(WideResNets,ResNets,MobileNets),(b)师生能力,(c)数据集(CIFAR-10/100,ImageNet),以及(d)领域(全精度-二进制)。
知识转移。在(Hinton等人,2015)的工作中,知识被定义为教师在最后的softmax层之后的输出。softmax输出比单热标签携带更丰富的信息,因为它们在教师学到的类间相似性方面提供额外的监督信号。与(Hinton等人,2015)类似,从教师那里提取的中间表征,如特征张量(Romero等人,2015)或注意力图(Zagoruyko & Komodakis,2017)已被用于定义用于促进学生优化的损失函数。试图像FitNets(Romero等人,2015)那样匹配整个特征张量是很难的,在某些情况下,这种方法可能会对学生的表现和收敛性产生不利影响。为了放松FitNet的假设,(Zagoruyko & Komodakis,2017)提出了注意力转移(AT),其中知识采取注意力图的形式,这是通道维度上特征张量能量的总结。Huang & Wang, 2017)提出了(Zagoruyko & Komodakis, 2017)的扩展,使用网络激活的最大平均差异作为提炼的损失项。Cho & Hariharan(2019)表明,非常准确的网络 "太好 "而不能成为好的教师,并提出通过早期停止教师的训练来缓解这一问题。最近,(Heo等人,2019a)的工作研究了网络中应该应用特征提炼的位置,并提出了边缘ReLU和专门设计的距离函数,只将有用(积极)的信息从教师转移到学生。最近,Li等人(Li et al., 2020a)提出通过从教师模型中提炼的架构知识来监督顺时针的架构搜索。另一个基于NAS的方法是在(Guan等人,2020)中提出的,其中学生对教师的损失被用来寻找与学生的学习能力相匹配的聚合权重。Passalis等人(2020)声称,传统的KD忽略了训练过程中的信息可塑性,并提出对通过教师各层的信息流建模。
特征关系转移。另一种知识提炼方法侧重于探索转移特征之间的关系,而不是实际特征本身。在(Yim等人,2017)中,通过计算教师和学生的跨层特征的Gram矩阵,然后在师生Gram矩阵对上应用L2损失来捕获特征的相关性。这项工作的局限性在于计算成本高,(Lee等人,2018)通过奇异值分解压缩特征图,在一定程度上解决了这个问题。Park等人(2019)提出了一种关系型知识提炼方法,该方法计算每个嵌入特征向量的距离和角度关系。这个想法在(Peng等人,2019)和(Liu等人,2019)中得到了进一步探讨。在(Peng等人,2019)中,提出了泰勒级数扩展,以更好地捕捉多个实例之间的关联性。在(Liu等人,2019)中,实例特征和关系分别被认为是图中的顶点和边,并提出了实例关系图来模拟跨层的特征空间转换。受语义相似的输入应该有相似的激活模式这一观察的启发,(Tung & Mori, 2019)提出了一种保留相似性的知识提炼方法,该方法引导学生模仿教师产生相似或不相似的激活。最近,(Jain等人,2020)提出通过量化的视觉词空间提炼知识,使学生的输出与教师的输出相匹配。Li等人(2020b)提出了局部相关探索框架,以表示特征空间中的局部区域的关系,其中包含更多的细节和判别模式。
最后,最近(Tian et al., 2020)在蒸馏和表征学习之间建立了类似的联系,它使用对比学习进行知识蒸馏。我们注意到,我们的损失与(Tian et al., 2020)中使用的损失没有关系,更简单,并且如第5节所示,在我们所有的实验中都优于它,往往有很大的优势。
我们用T和S分别表示教师和学生网络。我们将这些网络分成两部分。(i) 卷积特征提取器f Net, Net = {T, S},其在第i层的输出是一个特征张量,其中CiNet是输出特征维度,Hi, Wi是输出空间维度。我们还用表示由fNet学习的最后一层特征表示。 (ii) 一个投影矩阵,它将特征表示hNet投影到K类的logits:zi Net(i = 1, . . ,K),然后是温度为τ的softmax函数(对于交叉熵损失,τ=1),这些函数组合起来形成了一个转化到K类的softmax回归分类器。
知识提炼(KD)(Hinton等人,2015)对学生进行培训,其损失如下。
这样,教师和学生的分类器之间的差异就直接最小化了。
FitNets(Romero等人,2015)与中间的特征表示相匹配。对于第i层,定义了以下损失。
其中r(.)是一个用于匹配特征张量尺寸的函数。在我们的工作中,我们建议尽量减少表征hT和hS之间的差异。为了实现这一目标,我们建议使用两种损失。第一个是L2特征匹配损失。
其中,为了简化符号学,我们放弃了对r(.)的依赖。因此,LF M损失是一个简化的FitNet损失,它只关注学到的最终表征。这样做的直觉是,这个特征与分类器直接相关,因此将学生的特征强加于教师的特征,可能会对分类精度产生更大的影响。此外,人们可能会质疑为什么要像(Romero等人,2015)那样对其他中间表征进行优化,特别是当学生是一个较低表征能力的网络时。在第4节:损失应该用在哪里?中,我们确认单独的LF M有积极的影响,但其他层的特征匹配没有帮助。
我们发现LF M是有效的,但只是在有限的范围内。LF M的一个缺点,以及一般来说,所有的特征匹配损失,例如(Romero等人,2015;Zagoruyko & Komodakis,2017),是它独立处理特征空间中的每个通道维度,并忽略了最终分类的特征表示hS和hT的通道间依赖性。这与Hinton等人在(Hinton et al., 2015)中提出的直接针对分类精度的原始对数匹配损失形成对比。为了缓解上述问题,在这项工作中,我们提出了优化hS的第二个损失,它与分类准确性直接相关。为此,我们将使用教师预先训练的Softmax Regression(SR)分类器。
让我们用p表示教师网络在输入某种图像x时的输出。让我们也把同样的图像通过学生网络来获得特征hS(x)。最后,让我们把hS(x)通过教师的SR分类器,得到输出q,也见图1。(把学生网络在分类器前产生的特征提取出来,使用教师分类器去进行分类得到输出q)我们的损失被定义为:
在这一点上,我们提出以下两点意见。(1) 如果p=q(并且由于教师的分类器是冻结的),那么这意味着hS(x)=hT(x),这表明确实公式(4)优化了学生的特征表示hS(hT也被冻结)。(2) 公式(4)的损失可以写成:
现在让我们以类似的方式来写KD损失:
通过比较公式(5)和公式(6),我们看到,我们的方法唯一的区别是,冻结的、预先训练好的教师分类器被用于教师和学生。相反,在KD中,WS也被优化。这给了优化算法更多的自由度,特别是调整学生的特征提取器fS和学生的分类器WS的权重,以使损失最小。这对学生的特征表示hS的学习有影响,而这又会阻碍学生在测试集上的概括能力。我们通过第4节的实验来证实这一假设:表征的可转移性。最后,我们注意到,我们发现,在实践中,对数之间有一个L2损失:
比交叉熵损失的效果稍好。附录中给出了LSR的不同类型损失的比较。
表1:提议的损失(LF M和LSR)和蒸馏位置对CIFAR-100的测试集的影响
总的来说,在我们的方法中,我们使用三种损失来训练学生网络:
其中α和β是用来衡量损失的权重。教师网络是预训练的,在训练学生时固定下来。LCE是基于手头任务的真实标签的标准损失(例如,图像分类的交叉熵损失)。请注意,这导致了一个非常简单的训练学生的算法,总结在算法1中。
我们在CIFAR-100上进行了一组消融研究(见第5.1节),使用了教师(WRN-40-4)和学生(WRN-16-4)的广域网络(WRN)(关于网络定义,见第5节)。
LF M和LS R都有用吗?为了回答这个问题,我们做了3个实验:单独使用LF M,单独使用LSR,以及将它们结合起来使用LF M + LSR。表1的结果(前3行)清楚地表明,所有提议的变体都提供了明显的收益:当单独使用LF M和LSR时,Top-1的准确性分别提高了1%和2%。此外,当把它们结合在一起时,还获得了额外的0.4%的改进。重要的是,结果显示LSR明显比LF M更有效。在这一点上,我们进一步注意到,我们发现LF M在ImageNet实验中提供的收益越来越少。
损失应该应用在哪里?建议的损失也可以应用在网络的其他层。这对LF M来说是很直接的,我们也可以将LSR扩展到更多的层,通过使用AdaIN层将每一层的学生的平均特征转移到教师的相应层(Huang & Belongie, 2017)。一方面,在网络的早期应用损失可以确保后续层收到 "更好 "的特征。另一方面,早期层产生的特征并不专门针对某一特定类别。因此,在网络的末端应用蒸馏损失,在那里激活编码辨别性的、与任务相关的特征,应该导致潜在的更强大的模型。表1的结果(最后3行)证实了我们的假设。在网络中的多个点上应用损失,实际上反而损害了准确性。
教师和学生的相似性。知识提炼的总体目标是使学生模仿教师的输出,从而使学生能够获得与教师相似的表现。因此,为了了解学生对教师的模仿程度,我们用(a)教师和学生的输出之间的KL散度,以及(b)学生的预测和真实标签之间的交叉熵损失来衡量教师和学生的输出的相似性。
表2:在CIFAR-100的测试集上,教师和学生之间的KL散度,以及学生和真实标签之间的交叉熵。教师的前1名准确率为79.50%。
表3:L2距离:||hT-hS||2,以及在CIFAR-100的测试集上计算的NMI。
从表2可以看出,KD(Hinton等人,2015年)降低了KL散度,教师的输出提供了1.5%∼的准确度增益。A T(Zagoruyko和Komodakis,2017)也降低了KL散度,教师的输出提供了较小的准确性增益,即1.0%。此外,与KD和A T相比,建议的损失LF M和LSR以及它们的组合LF M+LSR显示出相当高的相似性。
图2:hS和hT在CIFAR-100的测试集上的可视化。以彩色方式观看效果更好。
表征的距离。表3显示了教师和学生表征hT和hS之间的二级距离。表3中的结果清楚地表明,LF M和LSR都缩小了距离,它们的组合最接近教师的距离。
归一化互信息(NMI)。此外,我们还计算了NMI(Manning et al., 2008),这是一个平衡的指标,可以用来确定特征聚类的质量。表3中的结果显示,LF M + LSR的NMI得分最高,这意味着特征的聚类效果更好。定性结果见图2,图中直观地显示了特征hS和hT。可以看出,LF M+LSR能够学习到更多的鉴别性特征,这也与定量准确性的提高有关。
表征的可转移性。按照(Tian等人,2020),本节旨在比较所学学生的表征hS的表征能力。为此,我们在CIFAR100上对学生进行了训练,然后将其作为一个冻结的特征提取器,在此基础上对2个数据集进行线性分类器训练。STL10 Coates等人(2011)和CIFAR100。我们比较了KD、CRD、LF M、LSR和LF M+LSR的转移能力。在STL上,提议的损失比KD的优势是显而易见的。重要的是,LSR在很大程度上优于KD,这证实了我们对公式(5)和(6)的分析。在STL上的最佳结果是由CRD获得的。然而,在CIFAR100,也就是目标蒸馏数据集上,我们的方法优于CRD。
我们在多个(a)网络架构(ResNet(He等人,2016)、Wide ResNet(Zagoruyko & Komodakis,2016)、MobileNetV2(Sandler等人,2018)、MobileNet(Howard等人,2017))中彻底评估了我们方法的有效性,这些网络架构具有不同的师生能力;(b)数据集(CIFAR10/100,ImageNet),以及(c)领域(实值和二值网络)。所有实验的训练细节都在附录中提供。我们用ResNet-N表示具有N个卷积层的剩余网络(He等人,2016)。我们用WRN-D-k表示具有D层和扩展率为k的WRN架构(Zagoruyko & Komodakis, 2017)。
对于上述设置,我们将我们的方法与KD(Hinton等人,2015)和A T(Zagoruyko & Komodakis,2017),以及最近的OFD(Heo等人,2019a)、RKD(Park等人,2019)、CRD(Tian等人,2020)的方法进行比较。
结果概述。从我们的实验中,我们得出结论,我们的方法在上述所有情况下都能提供一致的收益,在所有情况下都能超过所有考虑的方法。值得注意的是,我们的方法对最难的数据集(即CIFAR-100和ImageNet)特别有效。
表4:通过冻结f S并在上面训练一个线性分类器,从CIFAR100到STL-10和CIFAR100的表示的可转移性。提供了前1(%)的准确性。
表5:CIFAR-10上各种知识提炼方法的前1名准确率(%)。
5.1 CIFAR10/CIFAR100
对于CIFAR-10,我们的方法的Top-1性能显示在表5中。我们测试了代表学生和教师网络的不同网络结构的三个案例:前两个实验是用WRNs。接下来的三个实验是用ResNets。在最后一个实验中,教师和学生有不同的网络架构。总的来说,我们的方法在所有情况下都取得了最佳结果,KD(Hinton等人,2015)紧随其后。
对于CIFAR-100(Krizhevsky & Hinton, 2009),我们使用不同的结构对几个学生-教师网络进行了实验。实验被归为三组。第一组显示了使用WRN的不同师生能力的表现:差生-好老师(WRN-16-2;WRN-40-4),差生-好老师(WRN-10-10;WRN-16-10);好学生-好老师(WRN-16-4;WRN-40-4)。在第二组中,我们显示这些结果在使用不同的架构时也是成立的,在这种情况下是ResNet。最后一组是为了显示教师和学生拥有不同架构(MobileNetV2、ResNet和WRN)时的性能。
表11显示了我们方法的最高性能。我们观察到,对于几乎所有的配置,我们的方法都比以前的工作取得了一致的、明显的精度提升。此外,很难说哪种方法是第二好的,因为其余的方法在不同的配置下都有自己的优势。对于WRN实验,OFD排名第二。对于ResNet和混合结构实验,CRD排名第二。更多与其他方法的比较以及将我们的方法与KD和AT相结合所得到的结果在补充材料中提供。通过将我们的方法与其他方法相结合,可以获得进一步的改进,但这需要进行全面的调查,这超出了本文的范围。
表6:CIFAR-100上各种知识蒸馏方法的前1名准确率(%)。
5.2 Imagenet-1K
我们的实验包括两对网络,它们是ImageNet的流行设置(Russakovsky等人,2015)。第一个是从ResNet-34提炼到ResNet-18,第二个是从ResNet-50提炼到MobileNet(Howard等人,2017)。注意,按照(Tian et al., 2020)在ImageNet上的做法,对于KD,我们将KL损失的权重设置为0.9,交叉熵损失的权重设置为0.5,这有助于获得更好的准确性。
我们的结果列于表7。我们再次观察到,我们的方法比所有的竞争方法都取得了明显的改进。此外,没有任何一种方法一直处于第二位:对于ResNet-34到ResNet-18的实验,RKD是第二好的,而对于ResNet-50到MobileNet,CRD是第二好的。值得注意的是,在后者的实验中,CRD将教师和学生之间的差距缩小了1.27%,而我们的方法则将其缩小了2.36%。总的来说,我们在ImageNet上的结果验证了我们方法的可扩展性,并且表明,当应用于大规模数据集时,与其他竞争方法相比,我们取得了更有利的性能。
表7:与ImageNet上的最先进技术的比较。
表8:CIFAR-100上的全精度-二进制蒸馏结果:用一个实值教师ResNet-34来蒸馏二进制学生。ImageNet-1K上的实-二进制蒸馏结果:一个实值的ResNet-18被用来蒸馏一个二进制学生。OFD的结果可能是次优的。
5.3 二值网络蒸馏
…
5.4面部标志物检测
我们提出了一种知识提炼的方法,该方法优化了学生网络倒数第二层的输出特征,因此与表征学习有直接关系。我们的方法的一个关键是新提出的Softmax回归损失,它被认为是有效的表征学习的必要条件。我们表明,我们的方法在广泛的实验环境中始终优于其他先进的蒸馏方法,包括具有不同师生能力的多种网络结构(ResNet、Wide ResNet、MobileNet)、数据集(CIFAR10/100、ImageNet)和领域(实值和二进制网络)。