提高几乎所有机器学习算法性能的一个非常简单的方法是在相同的数据上训练许多不同的模型,然后对其预测进行平均化[3]。不幸的是,使用整个模型集合进行预测是很麻烦的,而且计算成本太高,不允许部署到大量的用户,特别是如果单个模型是大型神经网络。Caruana和他的合作者已经表明,有可能将集合的知识压缩到一个单一的模型中,这样更容易部署,我们使用不同的压缩技术进一步发展这种方法。我们在MNIST上取得了一些令人惊讶的结果,我们表明,通过将模型集合中的知识提炼成一个单一的模型,我们可以显著改善一个大量使用的商业系统的声学模型。我们还介绍了一种由一个或多个完整模型和许多专家模型组成的新型集合体,这些模型学会了区分完整模型所混淆的细粒度的类别。与专家的混合模型不同,这些专家模型可以被快速和平行地训练。
许多昆虫都有一个幼虫形态,它为从环境中提取能量和营养物质而优化,还有一个完全不同的成虫形态,它为旅行和繁殖的非常不同的要求而优化。在大规模机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管它们的要求非常不同。对于像语音和物体识别这样的任务,训练必须从非常大的、高度冗余的数据集中提取结构,但它不需要实时操作,而且可以使用大量的计算。
然而,部署到大量的用户,对延迟和计算资源有更严格的要求。与昆虫的类比表明,我们应该愿意训练非常繁琐的模型,如果这能使我们更容易从数据中提取结构。繁琐的模型可以是一个单独训练的模型集合,也可以是一个用非常强的正则器(如dropout)训练的单一大型模型。一旦繁琐的模型被训练出来,我们就可以使用另一种训练,我们称之为 “提炼”,将繁琐模型的知识转移到更适合部署的小模型中。这种策略的一个版本已经由Rich Caruana和他的合作者开创了。在他们的重要论文中,他们令人信服地证明了由大型模型组合获得的知识可以转移到一个小型模型中。
一个概念上的障碍可能阻止了对这一非常有前途的方法进行更多的研究,那就是我们倾向于用学习到的参数值来识别训练过的模型中的知识,这使得我们很难看到如何改变模型的形式而保持相同的知识。对知识的一个更抽象的观点是,它不受任何特定实例化的限制,它是一种后天习得的知识从输入向量到输出向量的映射。对于学习区分大量类的繁琐模型,通常的训练目标是使正确答案的平均对数概率最大化,但学习的一个副作用是,训练过的模型为所有的错误答案分配概率,即使这些概率非常小,其中一些也比其他的大得多。错误答案的相对概率告诉我们这个繁琐的模型是如何趋于一般化的。例如,一辆宝马的图像可能只有很小的几率被误认为是一辆垃圾车,但这种错误的可能性仍然比把它误认为是一根胡萝卜的可能性高很多倍。
人们普遍认为,用于训练的目标函数应尽可能准确地反映使用者的真正目标。尽管如此,当真正的目标是很好地推广到新数据时,通常训练模型来优化训练数据的性能。训练模型进行良好的泛化显然更好,但这需要关于泛化的正确方法的信息,而这些信息通常是不可用的。然而,当我们从一个大模型中提取知识到一个小模型中时,我们可以训练小模型以与大模型相同的方式进行归纳。如果繁琐的模型可以很好地泛化,例如,因为它是不同模型的大型集合的平均值,那么用相同方式训练的小模型在测试数据上的表现通常会比在训练集合的相同训练集上以正常方式训练的小模型好得多。
将繁琐模型的泛化能力转移到小模型上的一种明显的方法是将繁琐模型产生的类概率作为训练小模型的“软目标”。对于这个转移阶段,我们可以使用相同的训练集或单独的“转移”集。当复杂模型是较简单模型的大集合时,我们可以使用它们各自预测分布的算术或几何平均值作为软目标。当软目标具有高熵时,它们在每个训练案例中提供的信息要比硬目标多得多,在训练案例之间的梯度方差也要小得多,因此小模型通常可以用比原始繁琐模型少得多的数据进行训练,并使用更高的学习率。
对于像MNIST这样的任务,繁琐的模型几乎总是产生非常高置信度的正确答案,关于学习函数的大部分信息存在于软目标中非常小的概率的比率中。例如,一个版本的2可能被给出10 - 6的概率是3,10 - 9的概率是7,而另一个版本可能是相反的情况。这是有价值的信息,它定义了数据上丰富的相似结构(例如,它说哪些2看起来像3,哪些看起来像7),但在传递阶段,它对交叉熵代价函数的影响非常小,因为概率非常接近于零。
Caruana和他的合作者通过使用logit(最终softmax的输入)而不是softmax产生的概率作为学习小模型的目标来规避这个问题,他们最小化了繁琐模型产生的logit和小模型产生的logit之间的平方差。我们更通用的解决方案称为“蒸馏”,即提高最终softmax的温度,直到繁琐的模型产生合适的软目标集。然后我们在训练小模型时使用相同的高温来匹配这些软目标。我们稍后将说明,匹配这个繁琐模型的对数实际上是蒸馏的一种特殊情况。
用于训练小模型的传输集可以完全由未标记的数据组成,也可以使用原始的训练集。我们发现,使用原始训练集效果很好,特别是如果我们在目标函数中添加一个小项,鼓励小模型预测真正的目标,并匹配繁琐模型提供的软目标。
通常情况下,小模型不能精确匹配软目标,在正确答案的方向上出错是有帮助的。
神经网络通常通过使用“softmax”输出层产生类概率,该输出层通过将zi与其他logit进行比较,将为每个类计算的logit zi转换为概率qi。其中T是通常设置为1的温度。使用较高的T值会在类上产生较软的概率分布。在最简单的蒸馏形式中,知识被转移到蒸馏模型上,方法是在一个转移集中训练它,并对转移集中的每个情况使用软目标分布,该转移集中由使用softmax中具有高温的繁琐模型生成。训练蒸馏模型时使用相同的高温,但训练后它使用的温度为1。
当所有或部分传输集都知道正确的标签时,通过训练蒸馏模型生成正确的标签,可以显著改进该方法。一种方法是使用正确的标签来修改软目标,但我们发现更好的方法是简单地使用两个不同目标函数的加权平均值。第一个目标函数是软目标的交叉熵,这个交叉熵是使用蒸馏模型的softmax中的相同高温计算的,就像从麻烦的模型生成软目标时使用的一样。第二个目标函数是带有正确标签的交叉熵。这是使用蒸馏模型的softmax中完全相同的logits计算的,但温度为1。我们发现,在第二个目标函数上使用一个可调节的较低权重通常可以获得最佳的结果。由于软目标产生的梯度大小为1/ t2,因此在使用硬目标和软目标时,重要的是将其乘以t2。这确保了在使用元参数进行实验时,如果用于蒸馏的温度发生改变,则硬目标和软目标的相对贡献大致保持不变。
传输集中的每种情况都贡献了一个交叉熵梯度,dC/dzi,相对于蒸馏模型的每个logit, zi。如果繁琐的模型有logits vi,产生软目标概率pi,并且在温度为T时进行转移训练,则该梯度为:
如果温度与logit的大小相比较高,我们可以近似:
因此,在高温极限下,蒸馏等价于最小化1/2(zi−vi)2,前提是对数对每个转移情况分别为零。在较低的温度下,蒸馏对比平均值负得多的logit的匹配关注要少得多。这是潜在的优势,因为这些logit几乎完全不受用于训练繁琐模型的成本函数的约束,因此它们可能非常嘈杂。另一方面,非常消极的logit可能会传递关于繁琐模型所获得的知识的有用信息。这些效应中哪一种占主导地位是一个经验问题。我们表明,当蒸馏模型太小,无法捕获繁琐模型中的所有知识时,中间温度工作得最好,这强烈表明忽略大的负对数是有帮助的。
为了了解蒸馏工作的效果,我们在所有60,000个训练案例中训练了一个大型神经网络,它有两个隐藏层,包含1200个整流线性隐藏单元。如[5]所述,使用退出和权重约束对网络进行了强正则化。Dropout可以被看作是一种训练共享权重的指数级大的模型集合的方法。此外,输入图像为向任何方向抖动最多两个像素。该网络实现了67个测试误差,而一个较小的具有两个隐藏层的800个校正线性隐藏单元和没有正则化的网络实现了146个误差。但是,如果只通过增加匹配大网在20度温度下产生的软目标的额外任务来正则化小网,它就实现了74个测试误差。这表明,软目标可以将大量的知识转移到蒸馏模型中,包括从翻译的训练数据中学习到的关于如何泛化的知识,即使转移集不包含任何翻译。
当蒸馏网的两个隐藏层中每层含有300或更多单位时,所有高于8的温度都给出了相当相似的结果。但当这一数值从根本上降低到每层30个单位时,2.5到4的温度比更高或更低的温度效果要好得多。
然后,我们试着从转移集中删除数字3的所有示例。所以从蒸馏模型的角度来看,3是一个从未见过的神秘数字。尽管如此,蒸馏模型只有206个测试错误,其中133个是在测试集中的1010个3上。大多数的错误是由于3类的学习偏差过低造成的。如果这个偏差增加3.5(这将优化测试集中的总体性能),蒸馏模型将犯109个错误,其中14个错误发生在3秒。因此,在正确的偏差下,蒸馏模型得到了98.6%的测试3的正确率,尽管在训练中从未见过3。如果转移集只包含来自训练集的7s和8s,则蒸馏模型产生47.3%的测试误差,但当7和8的偏差减少7.6以优化测试性能时,这降至13.2%的测试误差。
在本节中,我们研究了用于自动语音识别(ASR)的集成深度神经网络(DNN)声学模型的效果。我们表明,我们在本文中提出的蒸馏策略达到了预期的效果,即将模型集合蒸馏为单个模型,该模型比直接从相同训练数据学习到的相同大小的模型工作得更好。
目前,最先进的ASR系统使用dnn将来自波形的特征的(短)时间上下文映射到隐藏马尔可夫模型(HMM)[4]的离散状态的概率分布。更具体地说,DNN每次在三电话状态集群上产生一个概率分布,然后解码器在HMM状态中找到一条路径,这是使用高概率状态和在语言模型下产生可能的转录之间的最佳折衷方案。
尽管可以(并且可取)训练DNN,通过对所有可能的路径进行边缘化,将解码器(以及语言模型)考虑在内,但通常情况下,训练DNN执行逐帧分类,方法是(局部地)最小化网络做出的预测与每个观察的状态的基真值序列强制对齐所给出的标签之间的交叉熵:
其中θ是声学模型P的参数,该模型将时间t, st的声学观测映射到概率P (ht|st;θ ')的“正确”HMM状态ht的值,它是由与正确的单词序列强制对齐确定的。该模型采用分布式随机梯度下降法训练。
我们使用了一个包含8个隐藏层的架构,每个隐藏层包含2560个整流线性单元和一个带有14000个标签(HMM目标ht)的最终softmax层。输入为26帧,每帧有40个Mel-scaled滤波器组系数,每帧提前10ms,我们预测第21帧的HMM状态。参数总数约为85M。这是Android语音搜索使用的声学模型的一个稍微过时的版本,应该被视为一个非常强大的基线。为了训练DNN声学模型,我们使用了大约2000小时的英语口语数据,产生了大约700M个训练示例。在我们的开发集上,该系统实现了58.9%的帧精度和10.9%的字错误率。
我们训练了10个独立的模型来预测P (ht|st;θ),使用与基线完全相同的架构和训练过程。使用不同的初始参数值对模型进行随机初始化,我们发现这在训练过的模型中创建了足够的多样性,使集合的平均预测显著优于单个模型。我们已经探索了通过改变每个模型看到的数据集来增加模型的多样性,但我们发现这不会显著改变我们的结果,所以我们选择了更简单的方法。对于蒸馏,我们尝试了[1,2,5,10]的温度,并对硬目标的交叉熵使用了0.5的相对权重,其中粗体表示表1中使用的最佳值。
表1表明,的确,我们的蒸馏方法能够从训练集中提取更多有用的信息,而不是简单地使用硬标签来训练单个模型。使用10个模型的集合所取得的帧分类精度的改进中,超过80%转移到蒸馏模型中,这与我们在MNIST上的初步实验中观察到的改进相似。由于目标函数不匹配,集成在WER的最终目标上有较小的改进(在23k字测试集上),但集成实现的WER改进再次转移到蒸馏模型上。
我们最近注意到通过匹配已经训练过的较大模型[8]的类概率来学习小型声学模型的相关工作。然而,他们使用一个未标记的大型数据集在1的温度下进行蒸馏,他们的最佳蒸馏模型只减少了小模型的错误率28%,当它们都使用硬标签训练时,大模型和小模型的错误率差距。
训练模型的集合是利用并行计算的一种非常简单的方法,通常反对集合在测试时需要太多的计算,可以通过使用蒸馏来解决。然而,对于集成还有另一个重要的反对意见:如果单个模型是大型神经网络,数据集非常大,训练时所需的计算量是过多的,即使它很容易并行化。
在本节中,我们将给出这样一个数据集的示例,并展示如何学习专家模型,每个模型都关注于不同的易混淆的类子集,从而减少学习集成所需的总计算量。专注于细粒度区分的专家的主要问题是他们很容易过拟合,我们描述了如何通过使用软目标来防止这种过拟合。
JFT是一个内部的谷歌数据集,它有1亿张带有15,000个标签的标记图像。当我们做这项工作时,谷歌的JFT基线模型是一个深度卷积神经网络[7],它已经在大量核上使用异步随机梯度下降训练了大约6个月。本训练使用了两种类型的并行[2]。首先,有许多神经网络的副本运行在不同的核集上,并处理来自训练集的不同小批次。每个副本计算当前小批处理上的平均梯度,并将该梯度发送到一个分片参数服务器,该服务器发回参数的新值。这些新值反映了自上次向副本发送参数以来参数服务器接收到的所有梯度。其次,每个副本通过在每个核上放置不同的神经元子集而分布在多个核上。集成训练是第三种可包装的并行但前提是有更多的内核可用。等待数年来训练一个模型集合不是一个选择,因此我们需要一个更快的方法来改进基线模型。
当类的数量非常大时,将繁琐的模型设置为一个集合是有意义的,该集合包含一个经过所有数据训练的通才模型和许多“专家”模型,其中每个“专家”模型的训练数据都高度丰富,来自非常容易混淆的类子集(就像不同类型的蘑菇)。这类专家的softmax可以通过将所有它不关心的类合并到一个单独的垃圾箱类中而变得更小。
为了减少过拟合和分担学习底层特征检测器的工作,每个专家模型都用通才模型的权值初始化。然后,通过训练专家,用一半来自其特殊子集的样本和一半来自训练集剩余部分的随机抽样,对这些权重进行轻微修改。训练后,我们可以通过将垃圾箱类的logit乘以专家类过采样比例的对数来修正有偏差的训练集。
为了为专家们推导出对象类别的分组,我们决定专注于我们的完整网络经常混淆的类别。尽管我们可以计算混淆矩阵并使用它作为查找此类聚类的方法,但我们选择了一种更简单的方法,不需要真实的标签来构造聚类。
特别地,我们对我们的通才模型的预测的协方差矩阵应用了一种聚类算法,这样一组经常被一起预测的类Sm将被用作我们的一个专家模型m的目标。我们对协方差矩阵的列应用了一种在线版的K-means算法,并得到了合理的聚类(如表2所示)。我们尝试了几种聚类算法,得到了类似的结果。
在调查当专家模型被蒸馏时会发生什么之前,我们想看看包含专家的整体表现如何。除了专家模型之外,我们总是有一个通才模型,这样我们就可以处理那些没有专家的类,从而可以决定使用哪些专家。给定一个输入图像x,我们用两个步骤进行top-one分类:
第1步:对于每个测试用例,我们根据通才模型找到n个最可能的类。
称这组类为k。在我们的实验中,我们使用n = 1。
第二步:我们取所有专家模型m,其可混淆类的特殊子集Sm与k有一个非空交集,称之为专家Ak的活动集(注意这个集可能是空的)。然后我们在所有的类中找到最小的全概率分布q:
其中KL为KL散度,pm pg为专家模型或通才全模型的概率分布。分布pm是m的所有专家类加上一个垃圾箱类的分布,因此当从全q分布计算其KL散度时,我们将全q分布分配给m的垃圾箱中的所有类的所有概率相加。
公式5没有一个通用的封闭形式的解,尽管当所有的模型为每个类别产生一个单一的概率时,解是算术或几何平均值,这取决于我们是使用KL(p, q)还是KL(q, p))。我们参数化q = sof tmax(z) (T = 1),并使用梯度下降优化logits z w.r.t eq. 5。注意,必须对每个图像进行这种优化。
从训练过的基线全网络开始,专家训练速度极快(JFT只需几天而不是数周)。而且,所有的专家都是完全独立训练的。表3给出了基线系统和结合专家模型的基线系统的绝对测试精度。在61个专家模型中,测试准确度总体上有4.4%的相对提高。我们还报告了条件检验的准确性,即仅考虑属于专家类的例子,并将我们的预测限制在该类的子集中的准确性。
对于我们的JFT专家实验,我们训练了61个专家模型,每个模型有300个类(加上垃圾箱类)。因为专家的类集不是互不相干的,所以我们经常有多个专家覆盖一个特定的图像类。表4显示了测试集示例的数量,使用专家时在位置1正确的示例数量的变化,以及按覆盖类的专家数量分解的JFT数据集top1准确性的相对改进百分比。当我们有更多的专家覆盖一个特定的类时,精确度的提高会更大,这一总体趋势让我们感到鼓舞,因为训练独立的专家模型非常容易并行化。
关于使用软目标而不是硬目标,我们的一个主要主张是,在软目标中可以携带许多有用的信息,而这些信息不可能由单个硬目标编码。在本节中,我们通过使用更少的数据来拟合前面描述的基线语音模型的85M参数,来证明这是一个非常大的影响。表5显示,在只有3%的数据(约20M个示例)的情况下,用硬目标训练基线模型会导致严重的过拟合(我们进行了早期停止,因为准确性在达到44.5%后急剧下降),而用软目标训练的相同模型能够恢复整个训练集中的几乎所有信息(约2%的不足)。更值得注意的是,我们不需要提前停止:具有软目标的系统简单地“收敛”到57%。这表明,软目标是一种非常有效的方式,可以将基于所有数据训练的模型所发现的规律传递给另一个模型。
我们在JFT数据集的实验中使用的专家将他们所有的非专家类分解为一个垃圾箱类。如果我们允许专家在所有类上都有一个完整的软最大,可能会有一个比使用早期停止更好的方法来防止它们过拟合。专家接受的是在其特殊类别中高度丰富的数据训练。这意味着它的训练集的有效规模要小得多,并且它对特殊类有很强的过拟合倾向。这个问题不能通过使专家类更小来解决,因为这样我们就失去了从建模所有非专家类中获得的非常有用的转移效果。
我们使用3%的语音数据进行的实验有力地表明,如果用通才的权重对一个专家进行初始化,除了用硬目标对它进行训练外,我们还可以用非特殊类的软目标对它进行训练,从而使它保留几乎所有关于非特殊类的知识。软性目标可以由通才提供。我们目前正在探索这种方法。
使用接受过数据子集训练的专家与混合专家[6]有一些相似之处,这些专家使用门限值网络来计算将每个示例分配给每个专家的概率。在专家们学习处理分配给他们的例子的同时,门控网络也在学习选择将每个例子分配给哪些专家,这是基于专家们对该例子的相对辨别性能。使用专家的判别性能来确定学习到的分配比简单地聚类输入向量并将一个专家分配到每个聚类要好得多,但它使训练难以并行化:首先,每个专家的加权训练集会根据其他所有专家的情况不断变化,其次,门限值网络需要比较不同专家在同一例子上的表现,以知道如何修改其分配概率。这些困难意味着混合专家很少用于他们可能最有益的情况:包含明显不同子集的庞大数据集的任务。
并行化多个专家的培训要容易得多。我们首先训练一个通才模型,然后使用混淆矩阵来定义训练专家的子集。一旦确定了这些子集,专家就可以完全独立地接受培训。在测试时,我们可以使用来自通才模型的预测来决定哪些专家是相关的,并且只有这些专家需要运行。
我们已经证明,在将知识从一个集合或从一个大型的高度正则化模型转移到一个更小的蒸馏模型时,蒸馏非常有效。在MNIST上,即使用来训练蒸馏模型的转移集缺少一个或多个类的例子,蒸馏也能很好地工作。对于Android语音搜索使用的深度声学模型,我们已经证明,通过训练一个深度神经网络集合所实现的几乎所有改进都可以被提炼为一个同样大小的单一神经网络,这要容易部署得多。
对于真正的大型神经网络,甚至训练一个完整的集合都是不可行的,但我们已经证明,一个经过很长时间训练的真正的大型网络,可以通过学习大量的专家网络来显著提高性能,每个专家网络都学会了在高度易混淆的簇中区分类。我们还没有证明,我们可以将专家的知识提炼回单一的大网中。