paper:Distilling the Knowledge in a Neural Network
提高几乎所有机器学习算法性能的一个非常简单的方法是在相同的数据上训练许多不同的模型,然后对它们的预测进行平均[3]。
不幸的是,使用整个模型集合进行预测非常麻烦,并且计算成本可能太高,无法部署到大量用户,尤其是在单个模型是大型神经网络的情况下。
Caruana 和他的合作者 [1] 已经证明,可以将集成中的知识压缩到单个模型中,该模型更容易部署,并且我们使用不同的压缩技术进一步开发了这种方法。
我们在 MNIST 上取得了一些令人惊讶的结果,并且表明我们可以通过将模型集合中的知识提炼为单个模型来显着改进频繁使用的商业系统的声学模型。
我们还引入了一种由一个或多个完整模型和许多专业模型组成的新型集成,这些模型学习区分完整模型混淆的细粒度类别。 与专家的混合不同,这些专业模型可以快速并行地进行训练。
许多昆虫都有幼虫形态和完全不同的成虫形态,幼虫形态经过优化,可以从环境中获取能量和营养,而成虫形态则可以满足不同的旅行和繁殖要求。
在大规模机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管它们的要求非常不同:对于语音和对象识别等任务,训练必须从非常大、高度冗余的数据集中提取结构,但它并不需要实时操作,因此可以使用大量的计算。
然而,部署到大量用户对延迟和计算资源有更严格的要求。 与昆虫的类比表明,如果可以更轻松地从数据中提取结构,我们应该愿意训练非常繁琐的模型(后面称为大模型)。
大模型可能是单独训练的模型的集合,也可能是使用非常强大的正则化器(例如 dropout)训练的单个非常大的模型[9]。
一旦繁琐的模型训练出来,我们就可以使用不同类型的训练,我们称之为“蒸馏”,将知识从繁琐的模型转移到更适合部署的小模型。 Rich Caruana 及其合作者已经率先提出了该策略的一个版本 [1]。 在他们的重要论文中,他们令人信服地证明,通过大型模型集合获得的知识可以转移到单个小型模型中。
可能阻止对这种非常有前途的方法进行更多研究的一个概念障碍是,我们倾向于使用学习到的参数值来识别经过训练的模型中的知识,这使得我们很难看到如何改变模型的形式但保持相同的知识。
知识的一个更抽象的观点是,它是从输入向量到输出向量的学习映射,将其从任何特定的实例化中解放出来。
对于学习区分大量类别的繁琐模型,正常的训练目标是最大化正确答案的平均对数概率,但学习的副作用是训练后的模型将概率分配给所有不正确的答案,即使这些概率非常小,其中一些也比其他概率大得多。
错误答案的相对概率告诉我们很多关于大模型如何泛化的信息。 例如,BMW的图像可能只有很小的机会被误认为是垃圾车,但这种错误的可能性仍然比将其误认为是胡萝卜高很多倍。
一般认为,用于训练的目标函数应尽可能地反映用户的真实目标。尽管如此,模型通常被训练为优化训练数据上的性能,而真正的目标是要对新数据具有良好的泛化能力。
显然,更好的做法是训练模型以便它们能够很好地泛化,但这需要关于正确泛化方式的信息,而这些信息通常是不可用的。
然而,当我们将大模型的知识提炼到小模型时,可以训练小模型与大型模型相同的方式进行泛化。
如果大模型泛化得好,例如,因为它是多个不同模型大型集合的平均,那么训练小模型以相同方式泛化,在测试数据上通常会比按照常规方式在同一个训练集上训练的小模型表现更好,训练集就是训练大模型的集合的。
将大模型的泛化能力转移到小模型的一个明显方法是使用大模型产生的class probability作为训练小模型的“soft targets”。
对于这个转移阶段,我们可以使用相同的训练集或单独的“转移”集。 当大模型是简单模型的大型集合时,我们可以使用它们各自的预测分布的算术或几何平均值作为soft targets。
当soft targets具有高熵时,它们在每个训练case中提供的信息比hard targets多得多,并且训练case之间梯度的方差要小得多,因此小模型可以用更少的数据,更大的learning rate进行训练。
对于像MNIST这样的任务,大模型几乎总以很高的置信度得出正确答案,大量关于学习function的信息寄存在soft targets中非常小概率的比率里。例如,一个版本中,2可能以10-6的概率被认为是3,10-9的概率被认为是7,而另一个版本可能恰好相反。这是有用的信息,因为它定义了数据的丰富的类似结构(即它指出哪些2看起来像3,哪些看起来像7),但在transfer阶段它对交叉熵损失函数的影响非常小,因为这些概率接近于零。
Caruana及其合作者通过使用logits(最后的softmax层的input)而不是用由softmax产生的概率作为学习小模型的target 来避开这个问题,并且他们最小化大模型和小模型产生的logits之间的平方差。更通用的解决方案,称为“蒸馏”,是将最后的softmax层的温度提高,直到大模型产出一套合适的soft target。然后训练小模型时用相同的高温,以匹配这些soft targets。我们稍后将展示,匹配大模型的logits实际上是蒸馏的一个特殊case。
(这里的 “温度” 在后面的公式中体现)
用于训练小模型的转移集可以完全由未标记数据组成[1],或者我们可以使用原始训练集。我们发现使用原始训练集效果很好,尤其是如果我们在目标函数中增加一个小项,鼓励小模型预测真实的target, 并且匹配由大模型提供的soft target。
通常,小模型无法完全匹配soft target,而在正确答案的方向上犯错被证明是有帮助的。
softmax的input称为logits, 用 z i z_{i} zi表示,
softmax的output称为概率,用 q i q_{i} qi表示。
神经网络通常用一个softmax层把logits转为概率,通过把 z i z_{i} zi与其他概率作比较。
公式里面的T就是上面说的蒸馏的温度。T通常是1. 更高的T产生更加soft的概率分布。
如何设置温度T?
在最简单的蒸馏形式中,准备一个transfer set数据集,它的label是大模型通过调高T产生的soft target,训练蒸馏模型时也要用同样的T,训练完成后T=1.
通过在transfet set上训练蒸馏模型,知识就被转移到了蒸馏模型。
同时使用label和soft target
当所有或部分transfer set的正确label已知时,还可以通过训练蒸馏模型来生成正确的标签来显着改进该方法。
一种方法是使用正确的label来修改soft target,但我们发现更好的方法是简单地使用两个不同目标函数的加权平均值。
第一个目标函数是与soft target的交叉熵,并且该交叉熵是让蒸馏模型和产生soft target的 大模型用相同的温度T(softmax中)。softmax 中与用于从繁琐模型生成软目标相同的高温来计算的。
第二个目标函数是和正确label的交叉熵。 这是在蒸馏模型的 softmax 中还是用完全相同的logits计算,但T= 1。
我们发现,通常通过在第二个目标函数上使用相当低的权重来获得最佳结果。
由于soft target产生的梯度幅度 相当于缩放了 1/T2 ,因此在同时使用hard 和 soft targets时将其乘以 T 2 非常重要。 这确保了如果在元参数实验时用于蒸馏的温度T发生变化,hard和soft target的相对贡献保持大致不变。
PS: 前面introduction部分提到过,用softmax的input, 也叫logits, 代替softmax输出的概率作为学习小模型的target,来避开概率过小的问题,通过最小化大模型和小模型产生的logits之间的平方差。
现在说明这种方法为什么是蒸馏的一种形式。
transfer set中每个case都对蒸馏模型的每个logits z i z_{i} zi贡献出cross-entropy梯度 d C / d z i dC/dz_{i} dC/dzi.
如果大模型有logits v i v_{i} vi, 产生了soft target概率 p i p_{i} pi, 训练在温度T下完成.
那么梯度为:
如果温度比logits的幅度大,那么可以近似为:
假设每个transfer case的logits都是0均值的,即,
那么(3)可以简化为:
所以在温度T高时,如果logits对每个tranfer case都是0 均值,那么蒸馏等同于最小化 1 / 2 ( z i − v i ) 1/2(z_{i} - v_{i}) 1/2(zi−vi).
在T比较低时,蒸馏在matching logits上的attetion就少很多,因为它们比平均值负很多。
这是潜在的优势,因为这些logits几乎完全不受大模型的cost function的约束,因此它们可能非常noisy。
另一方面,非常负的logits可能会传达有关通过大模型获得的知识的有用信息。 这些影响中哪一个占主导地位是一个经验问题。 我们表明,当蒸馏模型太小而无法捕获大模型中的所有知识时,中间温度效果最好,这强烈表明忽略大的负logtis可能会有所帮助。
为了了解蒸馏的效果如何,我们在所有 60,000 个训练案例上训练了一个大型神经网络,该神经网络具有两个隐藏层,每个隐藏层包含 1200 个校正线性隐藏单元。 该网络使用 dropout 和权重约束进行了强烈正则化,如 [5] 中所述。 Dropout 可以被视为训练共享权重的指数级大模型集合的一种方法。 此外,输入图像在任何方向上抖动最多两个像素。 该网络出现了 67 个测试错误,而具有两个隐藏层(由 800 个校正线性隐藏单元且无正则化)的较小网络出现了 146 个错误。 但是,如果仅通过添加在 20 ℃ 的温度下匹配大网络产生的软目标的附加任务来对较小的网络进行正则化,则它会出现 74 个测试错误。 这表明soft target可以将大量知识转移到蒸馏模型中,包括如何泛化从translated训练数据中学到的知识,即使转移集不包含任何translations。
当蒸馏网络的两个隐藏层中每个都有 300 个或更多units时,所有高于 8 的温度都会给出相当相似的结果。 但当这从根本上减少到每层 30 个units时,2.5 至 4 范围内的温度明显优于更高或更低的温度。
然后,我们尝试从传输集中省略数字 3 的所有示例。 所以从蒸馏模型的角度来看,3是一个它从未见过的神话数字。 尽管如此,蒸馏模型仅出现 206 个测试错误,其中 133 个位于测试集中的 1010 个三元组上。
大多数错误是由于3这个类别的学习bias太低而引起的。 如果此偏差增加 3.5(这会优化测试集的整体性能),则蒸馏模型会出现 109 个错误,其中 14 个错误位于 3 上。 因此,在正确的偏差下,尽管在训练期间从未见过 3,但蒸馏模型在测试 3 中的正确率达到 98.6%。 如果传输集仅包含训练集中的 7 和 8,则蒸馏模型的测试误差为 47.3%,但当 7 和 8 的偏差减少 7.6 以优化测试性能时,测试误差将降至 13.2%。
我们已经证明,蒸馏对于将知识从集成或从大型高度正则化模型转移到较小的蒸馏模型非常有效。 在 MNIST 上,即使用于训练蒸馏模型的传输集缺少一个或多个类的任何示例,蒸馏也能表现得非常好。 对于 Android 语音搜索所使用的深度声学模型版本,我们已经证明,通过训练深度神经网络集合所实现的几乎所有改进都可以被提炼为相同大小的单个神经网络, 部署起来要容易得多。
对于非常大的神经网络,甚至训练一个完整的集合也是不可行的,但是我们已经证明,经过很长时间训练的单个非常大的网络的performance 可以通过学习大量的专家网络来显着提高 ,每个专家网络都学会区分高度混乱的集群中的类别(通过大量专家网络进一步区分类别,是帮助的性质,并不是蒸馏)。 我们还没有证明我们可以将专家的知识蒸馏回单一的大网络中。