FedMD: Heterogenous Federated Learning via Model Distillation论文笔记

   本文提出使用迁移学习和知识蒸馏开发了一个联邦学习框架FedMD,此框架允许不同的客户端根据其计算能力设计不同的网络结构,并且保护数据集的隐私安全和模型的隐私安全条件下联合训练出一个模型。

论文地址:FedMD: Heterogenous Federated Learning via Model Distillation 2019 NIPS


算法细节
FedMD: Heterogenous Federated Learning via Model Distillation论文笔记_第1张图片
本文设有一个共享数据集 D 0 D_{0} D0,每个客户端有本地的数据集 D k D_{k} Dk和模型 f k f_{k} fk k = 1... m k=1...m k=1...m

迁移学习。每个客户端先依次用公共数据集和自己本地的数据集来训练自己的模型。
重复下面五步
  第一步:每个客户端用自己的本地模型来预测共享数据集并将预测结果分数 f k ( x i 0 ) f_{k}(x_{i}^{0}) fk(xi0)发给服务端。需要注意的是1.并不需要预测全部共享数据集,只需要随意选取一部分。原因是这样可以在不损害性能的前提下加速度;2.这里的预测预测结果分数是指不经过softmax的结果。

  第二步:服务中心将客户端传送的分类分数取平均,得到平均分数 f ( x i 0 ) = 1 m ∑ k f k ( x ) i 0 f(x_{i}^{0}) = \frac{1}{m} \sum_{k}f_{k}(x)_{i}^{0} f(xi0)=m1kfk(x)i0,即得到各个模型的一个全局共识。需要注意的是权重 1 m \frac{1}{m} m1是可以修改的。CIFAR中,作者稍微抑制了来自两个较弱模型(0和9)的贡献。当有非常不同的模型或数据时,这些权重可能变得更重要。

  第三步:每个客户端从服务器下载平均分数 f ( x i 0 ) f(x_{i}^{0}) f(xi0)

  第四步: 模型蒸馏。 每个客户端用模型蒸馏在共享数据集去拟合这个平均分数,即各个模型去学习全局共识。

  第五步:每个客户端在本地数据集上训练模型几个epoch。

思考
  很多的联邦学习框架中,客户端都是发送模型参数给服务端从而在服务端聚合成一个全局的模型,在这篇论文却提出了发送模型在公共数据集上预测分数,在服务端上集成这些分数得到一个全局共识,客户端模型再去学习这些共识。
  我认为这样的好处至少有三个:1.每个客户端可以根据自身条件训练出适合自己的模型,而不必全部客户端的模型都一样。2.在Deep Leakage from Gradients在这篇文章中指出在可以偷取模型更新梯度的情况下可以还原出训练数据。所以如果发送模型参数显然会增加数据隐私泄露的风险,而发送预测分数则不会出现这样的风险。3.减少传输的数据量。
  当然发送预测分数不足的地方也不是没有的,每个客户端都与要牺牲一部分数据隐私来组成一个共享数据集,并且共享数据集的分布也是非常重要的。

你可能感兴趣的:(论文笔记,联邦学习,知识蒸馏,迁移学习)