Federated Knowledge Distillation 联邦知识蒸馏

背景介绍

机器学习在当前是应用广泛的技术,在一些应用场景中,机器学习模型用大量数据样本进行训练,这些样本由边缘设备(如手机、汽车、接入点等)产生并分散在各处。收集这些原始数据不仅会产生大量的通信开销,还可能会侵犯数据隐私。在这方面,联邦学习是一个有希望的通信效率和隐私保护的解决方案,它定期交换本地模型参数,而不共享原始数据。然而,在现代深度神经网络架构下,交换模型参数的成本非常高,因为它往往有大量的模型参数。例如,MobileBERT是一个用于自然语言处理任务的最先进的神经网络架构,有2500万个参数。在有限的无线带宽资源下,通过每轮通信交换近100MB的有效载荷来训练这样一个模型显然是很有挑战性的。

联邦学习的局限性促使了联邦蒸馏的发展。联邦蒸馏的基本思想是交换本地模型输出,其占用的存储空间通常比模型尺寸小得多(例如,MNIST数据集中的10个标签)。每隔一段时间,对每个客户端生成的本地logits进行平均,并上传到参数服务器,以便对每个ground-truth标签的各客户端的本地平均日志进行汇总和全局平均。每个工作者都会下载每个ground-truth标签的全局平均logits。最后,为了将下载的全局知识转移到本地模型中,每个客户端通过最小化自己的损失函数来更新其模型参数,此外还有一个正则器来惩罚自己对给定样本的logit和对给定样本的ground-truth的全局平均logit之间的较大差距。

知识蒸馏和互蒸馏

联邦知识蒸馏是建立在两个基本算法之上的。一个是知识蒸馏(Knowledge Distillation, KD),将预先训练好的教师模型的知识转移到学生模型中,而另一个是KD的在线版本,没有预先训练教师模型,称为互蒸馏(Co-Distillation, CD)。

知识蒸馏是近年来发展起来的一类模型压缩与加速技术,其主要是利用一个已经训练好的复杂模型(作为教师),将其学习到的决策信息(知识)迁移到另一个轻量级模型(作为学生)中,帮助和指导学生模型的训练。知识蒸馏旨在通过将知识从深度网络转移到小型网络来压缩和改进模型。知识蒸馏又可以按照知识的类别分为以下几类:

Logits KD: 作为知识蒸馏的开山之作,Hinton在2014年提出基于logits的知识蒸馏方法,主要思想在于用学生网络的预测logits去学习教师网络的输出logits,从而引导学生网络训练,可以学习到自身预测不出来的类之间的相似性知识。主要方法是通过基于温度参数的softmax函数,对输出logits进行软化,将其看作一种知识从教师端转移到学生端。

Hints KD: Fitnets是第一个考虑到模型中间隐藏层的知识蒸馏方法,其主要是将模型隐藏层的特征看作是一种知识,然后学生网络通过学习教师网络的隐藏层特征知识,可以提升学生模型自己的性能,这种方法可以和logits KD方法在一起结合使用。

Attention KD: 将神经网络的Attention Map作为知识进行蒸馏,并定义了基于激活图与基于梯度的注意力分布图,设计了注意力蒸馏的方法。这一方法将注意力也视为一种可以在教师与学生模型之间传递的知识,然后通过设计损失函数完成注意力传递,本质上来说学生模型学习到了教师模型针对输入数据权重更高的地方,即输入数据对模型的影响程度。

Similarity KD: 相似的输入会倾向于在训练的网络中引起相似的激活模式,与以前的蒸馏方法相比,学生不需要模仿教师的表示空间,而是需要在其自己的表示空间中保持与教师网络成对的相似性。保持相似性的知识蒸馏指导学生网络的训练,使在训练的教师网络中产生相似激活的输入也在学生网络中产生相似激活。更进一步来说,如果两个输入在教师网络中产生高度相似的激活,那么引导学生网络,这也会导致两个输入在学生中产生高度相似的激活;相反地,如果两个输入在教师中产生不同的激活,我们就希望这些输入在学生中也产生不同的激活。

Relation KD: 关系知识蒸馏可以转移数据示例的相互关系,作者主要将关系结构看作是一种知识,然后又通过欧式距离与余弦距离作为损失函数来传递知识,从而使得关系知识蒸馏训练学生模型形成与教师相同的关系结构。

知识蒸馏一般假定一个预先训练好的教师模型,这在联邦学习框架中一般不太现实。然而,CD作为KD的一个在线版本,不需要预先训练的教师模型。CD的关键思想是将多个模型的预测输出集合作为教师的知识,这通常比单个预测输出更准确。

联邦知识蒸馏

CD在实现快速分布式学习和高精确度方面有很大的潜力,然而它的通信效率仍然是一个待解决的问题。其根本原因可以追溯到KD,它需要学生和教师模型共同观察训练样本。对于一个在线版本的KD,这意味着所有客户端在每次损失计算时都应该对相同的样本进行预测,这就需要大量的样本交换,也可能带来潜在的隐私风险。消除这种对共同样本观察的依赖是联邦知识蒸馏的关键动机,接下以将联邦知识蒸馏用于分类任务来将详细说明。

在分类任务中,FD通过根据标签对样本进行分组,避免了上述CD中需要交换大量样本的问题,从而将CD扩展为一个具有通信效率的分布式学习框架。FD的操作被概括为以下四个步骤。

  1. 在本地训练期间,每个客户端为每个标签存储一个平均logit向量。
  2. 每个工作者定期将其本地平均logit向量上传到参数服务器上,对所有客户单上传的本地平均logit向量每个标签都进行平均,。
  3. 每个客户端从服务器上下载所有标签的全局平均logit向量。
  4. 在基于KD的本地训练中,每个客户端选择其教师的logit为下载的全局平均logit,它与当前训练样本的真实标签有相同的标签。

在MNIST数据集上的结果显示,在一个由2个客户端参与的5层卷积网络的设置下,FD比FL实现了4.3倍的收敛速度,而准确率损失却不到10%。在更复杂的实验设置下,实验结果显示对于不同数量的客户端,与FL相比,FD每轮通信总能减少约10,000倍的通信有效载荷大小。考虑到快速收敛和有效载荷大小的减少,与FL相比,FD将收敛前的总通信成本降低了40,000倍以上。尽管如此,FD仍会带来训练精度的损失,特别是在non-IID数据分布的情况下。

此外还有一些工作尝试改进联邦知识蒸馏的性能,例如FedGen提出了一种data-free知识蒸馏法来解决FL中的异构性问题。其中服务器学习一个轻量级生成器,以data-free的方式集成用户信息,然后广播给用户,使用学习到的知识作为"归纳偏置"来调节局部训练。FedGen学习一个仅从用户模型的预测规则导出的生成模型(在给定目标标签的情况下,该模型可以产生与用户预测的集合一致的特征表示)。该生成器随后被广播给用户,用户从潜在空间(生成器产生的分布空间)采样得到的增广样本帮助模型训练(该潜在空间体现从其他对等用户提取的知识)。给定一个比输入空间小得多的潜在空间,FedGen所学习的生成器可以是轻量级的,给当前的FL框架带来最小的开销。

也有人对联邦知识蒸馏有不一样的理解,例如FedKD并不将蒸馏直接用于各个客户端之间传递知识,而是用在客户端本地通过蒸馏训练教师模型和学生模型,在聚合时每个客户端只交换参数量更小的学生模型,这使得联邦学习过程的通信效率更高。在本地的蒸馏过程中,教师模型是层数更深、参数量更多的模型,它的容量使得它能更好直接从本地数据训练得好,而层数较小、参数量较小的学生模型则是经过全局聚合后的,它拥有全局知识,在蒸馏过程中能将全局知识转移到教师模型中。

你可能感兴趣的:(深度学习,机器学习,人工智能)