联邦学习以轮为单位,每轮包含设备选择、参数分发、本地更新和全局更新这4个步骤
服务器端选择该轮参与训练的设备,在设备选择阶段,有的设备可能处于离线状态,需要选择在线的设备,并且设备的电量、网络状况等符合一定的要求,被选中的设备参与该轮训练。
本地设备下载到服务器发来的新的模型参数后,在此基础上,用本地的数据训练更新模型。
设备把本地更新了的模型发到服务器端,服务器按照一定规则进行聚合,对全局模型进行更新
联邦学习的算法,源自分布式机器学习。2015年CCS中提出了分布选择随机梯度下降(DSSGD)算法,它是一个异步的协议,分为下载、训练和上传三个阶段。下载阶段,客户端可以选择一部分参数来进行更新本地模型;训练阶段,客户端在本地进行训练;上传阶段,客户端可以选择本地模型的一部分参数上传给服务器。每个客户端完成一次训练,马上将最新的参数选择一部分进行上传,服务器立即更新全局模型,然后进行广播。
之前的DSSGD存在的问题是通信量巨大,而且是异步的,不能用于很多用户的场景。2017年谷歌的Mcmahan等提出了FedAvg算法,是一个同步的协议,全局更新的每一轮可以有上百个客户端,进行加权平均,是目前主流的联邦学习算法。
在上面的算法中,所有的梯度都是以明文的形式给出的,然而,从梯度会泄露用户的个人信息,在最新的NeurIPS 2019中,《Deep Leakage from Gradients》一文指出,从梯度可以推断出原始的训练数据,包括图像和文本数据。谷歌的Bonawitz等人,提出了安全聚合SMPC加密方案,服务器只能看到聚合完成之后的梯度,不能知道每个用户的私有的真实梯度值。
在说安全聚合SMPC前,先说说一个常用的密钥交换协议——DH密钥交换。DH密钥交换的目的,是让想要通信的Alice、Bob双方,他们之间能够拥有一个私密的密钥,这个密钥只有A和B两个人知道。DH密钥交换包含如下步骤:
1. 首先,Alice和Bob商量好DH的参数,一个大数素数 P \mathcal{P} P,和 z p \mathbb{z}_p zp上的一个生成元 G G G( 1 < G < P 1
2. Alice和Bob都各自产生一个随机数,A和B是Alice和Bob的私钥
3. Alice和Bob分别计算 G A = G A ( m o d P ) G^A=G^A(mod \mathcal{\;P}) GA=GA(modP)和 G B = G B ( m o d P ) G^B=\mathcal{G}^B(mod \mathcal{\;P}) GB=GB(modP), G A G^A GA和 G B G^B GB是Alice和Bob的公钥。(由公钥推导出私钥是困难的)
4. Alice和Bob分别将公钥发送给对方
5. Alice收到Bob发来的他的公钥 G B G^B GB,计算出用来和Bob秘密通信的密钥 s A B = ( G B ) A ( m o d P ) s_{AB}=(G^B)^A(mod \mathcal{\;P)} sAB=(GB)A(modP);同理Bob收到alice发来的他的公钥 G A G^A GA,计算出用来和Bob秘密通信的密钥 s B A = ( G A ) B ( m o d P ) s_{BA}=(G^A)^B(mod \mathcal{\;P)} sBA=(GA)B(modP)。显然 s A B = s B A s_{AB}=s_{BA} sAB=sBA是相等的,他们在公开环境中,可以通过密钥建立私有通信通道,使用该密钥来加密消息。
联邦学习中的安全聚合是基于安全多方计算的,安全多方计算是基于秘密分享的,秘密分享由1978年被Shamir提出(RSA中的S)。
L e m m a r 1 \mathcal{Lemmar \;1} Lemmar1:一个二维平面上, 给出任意 k k k个点 ( x 1 , y 1 ) , . . . , ( x k , y k ) (x_1,y_1), ... ,(x_k,y_k) (x1,y1),...,(xk,yk)的坐标,有且仅有一个 k − 1 k-1 k−1次的多项式 q ( x ) q(x) q(x),对于所有给定的 x i x_i xi,使得 q ( x i ) = y i q(x_i)=y_i q(xi)=yi。
假设秘密 s = f ( 0 ) s=f(0) s=f(0), s s s被分享给 n = 3 n=3 n=3个用户,阈值 t = 2 t=2 t=2。
在原始的FedAVG中,用户 u u u发送更新值 y u y_u yu给服务器, y u y_u yu是其真实的模型更新值,服务器进行聚合,再按照聚合规则取平均:
用户u和v之间通过DH建立秘密通信通道,他们之间知道一个秘密随机数 s u v s_{uv} suv。用户1发送给服务器的更新值 y 1 y_1 y1是真实值 x 1 + 0 − ( s 12 + s 13 ) x_1+0-(s_{12}+s_{13}) x1+0−(s12+s13),这样服务器收到 y 1 y_1 y1时,并不知道 x 1 x_1 x1是多少。服务器对收到的所有值进行聚合以后,它们正负才会抵消,相当于真实值的聚合,等同于FedAVG。
上面的方案是存在问题的,假设用户2在上传 y 2 y_2 y2时掉线了,没有把 y 2 y_2 y2发送给服务器,那么这一轮全局更新中,服务端的聚合值 ∑ y i \sum{y_i} ∑yi是没有意义的。
因为上面的方案存在用户掉线后,聚合值失效的问题,所以考虑带恢复的方案。当用户2掉线时,在恢复阶段,它的值 s 12 s_{12} s12和 s 23 s_{23} s23用户1和用户3是知道的,服务器询问用户1和用户3,用户1和用户3进行报告。在恢复阶段结束后,服务器完成聚合。
完整的SMPC的方案,可以在谷歌的论文中读到,可以参阅《Practical Secure Aggregation for Privacy-Preserving Machine Learning》。
图大部分引用自论文:
[1] Jack Sullivan. “Secure Analytics: Federated Learning and Secure Aggregation”. Jan 2020. URL: https://inst.eecs.berkeley.edu/~cs261/fa18/scribe/10_15_revised.pdf
谷歌的方案:
[2] Bonawitz K, Ivanov V, Kreuter B, et al. Practical Secure Aggregation for Privacy-Preserving Machine Learning[C]. CCS, 2017: 1175-1191.
谷歌FL论文作者演讲PPT:
[3] Jakub Konečný. “Federated Learning-Privacy-Preserving Collaborative Machine Learning without Centralized Training Data”. Jan 2020. URL: http://jakubkonecny.com/files/2018-01_UW_Federated_Learning.pdf
[4] Jakub Konečný. “Federated Learning”. FL-IJCAI’19 ppts. Jan 2020. URL: http://fml2019.algorithmic-crowdsourcing.com/