第一次开始写博客,就从前两天调研的联邦学习开始好了。
下面将以联邦学习的分类、联邦学习与分布式学习的区别、联邦学习的前景与应用以及模型鲁棒性四个方面展开。
联邦学习的孤岛数据有不同的分布特征。对于每一个参与方来说,自己所拥有的数据可以用一个矩阵来表示。设矩阵Di表示第i个参与方的数据;设矩阵Di的每一行表示一个数据样本,每一列表示一个具体的数据特征(feature)。同时,一些数据集还可能包含标签信息。我们将特征空间设为X,数据标签(label)空间设为Y,并用I表示数据样本ID空间。
特征空间X、数据标签空间Y和样本ID空间I组成了一个训练数据集(I,X,Y)。不同的参与方拥有的数据的特征空间和样本ID空间可能都是不同的。例如,在金融领域,数据标签可以是用户的信用度或者征信信息;在市场营销领域,数据标签可以是用户的购买计划;在教育领域,数据标签可以是学生的成绩分数。
根据训练数据在不同参与方之间的数据特征空间和样本ID空间的分布情况,联邦学习可被分为横向联邦学习(Horizontal Federated Learning,HFL)、纵向联邦学习(Federated Transfer Learning,FTL) 、迁移联邦学习(Vertical Federated Learning,VFL)。
横向联邦学习适用于联邦学习的参与方的数据有重叠的数据特征,即数据特征在参与方之间是对齐的,但是参与方拥有的数据样本是不同的。横向联邦学习中数据特征重叠维度较多,根据重合维度进行对齐,取出参与方数据中特征相同而用户不完全相同的部分进行联合训练。
图 横向联邦学习(按样本划分的联邦学习)
纵向联邦学习适用于联邦学习参与方的训练数据有重叠的数据样本,即参与方之间的数据样本是对齐的,但是在数据特征上有所不同。纵向联邦学习用户重合较多,根据用户ID进行匹配,取出参与方数据中用户相同而特征不完全相同的部分进行联合训练。
图 纵向联邦学习(按特征划分的联邦学习)
联邦迁移学习适用于参与方的数据样本和数据特征重叠都很少的情况。目前大部分的研究是基于横向联邦学习和纵向联邦学习的,迁移联邦学习领域的研究暂时还不多。
图 联邦迁移学习
联邦学习根据不同场景可以分为两大类:“跨设备”(cross-device)和“跨孤岛”(cross-silo)。表 1所示是两个类型的主要区别。
“跨设备”类型着重于整合大量移动端和边缘设备应用程序。跨设备的例子:Google的Gboard移动键盘,Apple正在iOS 13中使用跨设备FL用于QuickType键盘和“ Hey Siri”的人声分类器等。
“跨孤岛”类型一些可能只涉及少量相对可靠的客户端的应用程序,例如多个组织合作训练一个模型。跨孤岛的例子包括再保险的财务风险预测,药物发现,电子健康记录挖掘,医疗数据细分和智能制造等。
跨孤岛 | 跨设备 | |
---|---|---|
例子 | 医疗机构 | 手机端应用 |
节点数量 | 1~100 | 1~1010 |
节点状态 | 节点几乎稳定运行 | 大部分节点不在线 |
主要瓶颈 | 计算瓶颈和通信瓶颈 | wifi速度,设备不在线 |
按数据类型分类 | 横向/纵向 | 横向 |
联邦学习并不只是使用分布式的方式解决优化问题。联邦学习和分布式学习均是在多个计算节点上进行模型运算,但仍有不少区别,其主要区别如下表所示。
联邦学习 | 分布式训练 | |
---|---|---|
数据分布 | 分散存储且固定,数据无法互通、可能存在数据的Non-IID(非独立同分布) | 集中存储不固定,可以任意打乱、平衡地分配给所有客户端 |
节点数量 | 1~1010 | 1~100 |
节点状态 | 节点可能不在线 | 所有节点稳定运行 |
联邦学习是面向隐私保护的ML的框架,原始数据分散保存在各个设备上并进行训练,节点间数量较多且质量严重不均。服务器聚合各个本地计算的模型更新。
分布式学习则利用多个计算节点进行机器学习或者深度学习的算法和系统,其旨在提高性能,并可扩展至更大规模的训练数据和更大的模型。各节点间数据共享,任务由服务器统一分配,各节点比较均衡。
表 数据集中式分布式学习与跨孤岛/跨设备联邦学习的综合对比
数据集中式的分布式学习 | 跨孤岛的联邦学习 | 跨设备的联邦学习 | |
---|---|---|---|
设置 | 在大型但“扁平”的数据集上训练模型。客户端是单个群集或数据中心中的计算节点。 | 在数据孤岛上训练模型。客户是不同的组织(例如,医疗或金融)或地理分布的数据中心。 | 客户端是大量的移动或物联网设备 |
数据分布 | 数据被集中存储,可以在客户端之间进行混洗和平衡。任何客户端都可以读取数据集的任何部分。 | 数据在本地生成,并保持分散化。每个客户端都存储自己的数据,无法读取其他客户端的数据。数据不是独立或相同分布的。 | 与跨孤岛的数据分布一样 |
编排方式 | 中央式编排 | 中央编排服务器/服务负责组织培训,但从未看到原始数据。 | 与跨数据孤岛编排方式一样 |
广域通讯 | 无(在一个数据中心/群集中完全连接客户端)。 | 中心辐射型拓扑,中心代表协调服务提供商(通常不包含数据),分支连接到客户端。 | 与跨孤岛的广域通讯方式一样 |
数据可用性 | 所有客户端都是可用的 | 所有客户端都是可用的 | 在任何时候,只有一小部分客户可用,通常会有日间或其他变化。 |
数据分布范围 | 通常1-1000个客户端 | 通常2~1000个客户端 | 大规模并行,最多10^10个客户端。 |
主要瓶颈 | 在可以假设网络非常快的情况下,计算通常是数据中心的瓶颈。 | 可能是计算和通信量 | 通信通常是主要的瓶颈,尽管这取决于任务。通常跨设备联邦学习使用wifi或更慢的连接。 |
可解决性 | 每个客户端都有一个标识或名称,该标识或名称允许系统专门访问它。 | 与数据集中式的分布式学习一样 | 无法直接为客户建立索引(即不对用户进行标记)。 |
客户状态 | 有状态的-每个客户都可以参与到计算的每一轮中,不断地传递状态。 | 有状态的-每个客户都可以参与到计算的每一轮中,不断地传递状态。 | 高度不可靠-预计有5%或更多的客户端参与一轮计算会失败或退出(例如,由于违反了电池,网络或闲置的要求而导致设备无法使用)。 |
客户可靠性 | 相对较少的失败次数 | 相对较少的失败次数。 | 无状态的-每个客户在一个任务中可能只参与一次,因此通常假定在每轮计算中都有一个从未见过的客户的新样本。 |
数据分区轴 | 数据可以在客户端之间任意分区/重新分区。 | 固定分区。能够根据样本分区(横向)或者特征分区(纵向)。 | 根据样本固定分区(横向)。 |
如下表所示,联合学习应用领域广泛。谷歌的研究人员致力于在Gboard应用程序上从用户生成的数据增强语言建模。其他人发现联合学习非常适合医疗保健领域,可以通过在医院保留患者数据来平衡患者隐私和机器学习。物联网设备也在联合学习上获得了关注。
此外,联合学习也进入了许多其他领域,如边缘计算、网络 、机器人 、网格、联合学习增强、推荐系统 、网络安全、在线零售商、无线通信和电动汽车。
表 联邦学习应用领域
领域 | 训练设备类型 | 目标 | 训练模型 | 聚合算法 |
---|---|---|---|---|
Google Gboard 应用 | 移动电话 | 语言建模:键盘搜索建议 | Logistic Regression | FedAvg |
Google Gboard 应用 | 移动电话 | 语言建模:下一个单词预测 | RNN | FedAvg |
Google Gboard 应用 | 移动电话 | 语言建模:表情符号预测 | RNN-LSTM | FedAvg |
Google Gboard 应用 | 移动电话 | 语言建模:词汇外的学习 | RNN-LSTM | FedAvg |
医疗健康 | 医院 | 死亡率的预测 | 神经网络 | 提出的FADE |
医疗健康 | 医院 | 死亡率和住院时间预测 | 深度学习 | FedAvg |
医疗健康 | 医院、患者 | 住院治疗的预测 | Sparse SVM | 提出的CPDS |
医疗健康 | 连接到患者设备上的手机 | 医疗系统异常检测 | 神经网络 | 非加权的参数平均 |
医疗健康 | 组织 | 人类活动识别 | 深度神经网络 | n/a |
医疗健康 | 中心 | 神经系统疾病患者的脑变化分析 | 特征提取 | 交替方向乘子法 |
医疗健康 | 机构 | 脑瘤 | CNN: U-Net | FedAvg |
医疗健康 | 机构 | 影像分类 | 深度神经网络 | FedAvg |
医疗健康 | 脑电图(EEG)设备 | 脑电图信号分类 | CNN | FedAvg |
物联网系统 | 网关监控物联网设备 | 异常检测 | RNN-GRU | FedAvg |
物联网系统 | 物联网对象或协调器(云服务器-边缘设备) | 资源受限设备的轻量级学习 | 深度神经网络 | n/a |
物联网系统 | 物联网设备 | Computation Offloading | 双深度Q学习 | n/a |
物联网系统 | 移动电话和移动边缘计算服务器 | 提升物联网制造商服务 | 分块深度模型训练 | n/a |
边缘计算 | 用户设备、边缘节点 | Computation Offloading 边缘兑现 | 强化学习 | n/a |
网络 | 机械类型设备(MTD) | 资源块分配和电力传输 | 马尔科夫链 | MTDs流量模型的聚合 |
网络 | 代理 | 生成Q网络政策 | Q网络 | 多层感知器 |
机器人学 | 机器人 | 机器人导航决策 | 强化学习 | 提出的一种知识融合算法 |
联合学习增强 | 边缘节点 | 聚合频率的确定 | 基于梯度下降的ML模型 | FedAvg |
推荐系统 | 任何用户设备(包括笔记本电脑、手机) | 生成个性化推荐 | 协同过滤 | 梯度聚合以更新因素向量 |
网络安全 | 网关监控的桌面节点 | 异常检测 | 自动编码 | FedAvg |
在线零售 | 客户 | 点击流的预测 | RNN-RGU | FedAvg |
无线通信 | 增强现实使用者 | 边缘兑现 | 自动编码 | n/a |
无线通信 | 无线电设备 | 频谱管理 | 频谱使用率模型 | n/a |
无线通信 | 核心网络中的实体 | 5G核心网络 | n/a | n/a |
电动汽车 | 车辆 | 电动汽车的故障预测 | RNN-LSTM | 基于损失函数的加权平均 |
Federated Learning系统可能遭受各种故障。这些故障不仅包括非恶意故障,还包括针对训练和部署管道的显式攻击。Federated Learning的分布式性质、体系结构设计和数据约束打开了新的故障模式和攻击面。
此外,在Federated Learning中保护隐私的安全机制可以使检测和纠正这些故障和攻击成为一项特别具有挑战性的任务。接下来关于Federated Learning鲁棒性的探讨将针对恶意攻击和非恶意故障展开。
由于Federated Learning模型是通过(可能是大型的)大量不可靠的设备使用私有的、不可检查的数据集来训练模型的,因此Federated Learning可能会在训练时引入新的攻击面。恶意攻击一般有两种类型的攻击方式:数据中毒和模型更新中毒,模型如何规避攻击也同样需要探讨。这些攻击大致可分为训练时攻击(中毒攻击)和推理时攻击(逃避攻击)。
数据中毒
在数据污染攻击中,攻击者通过替换标签或数据的特定功能来操纵客户数据,而非直接破坏向中央节点中的联合模型。
通过用错误分类的标签替换正确标签或者注入了有毒数据的方式来攻击操纵训练阶段,导致模型本身行为不当。标签翻转是脏标签攻击的一种特殊情况,被证明是Federated Learning中的漏洞之一,使用差分隐私技术可以减轻这种攻击。
数据清理和网络修剪是专为数据中毒攻击而设计的防御措施。
模型更新中毒
模型更新中毒攻击不是向训练集中注入恶意数据,而是通过欺骗本地模型直接破坏全局模型。
在联合设置中,可以通过直接破坏客户端的更新或某种中间人攻击来执行此操作。与数据中毒攻击相比,模型中毒攻击看起来不那么自然,但却更有效。如果入侵者之间串通,则模型更新中毒攻击的有效性可能会大大提高,这种勾结可以使对手创建更有效且更难检测的模型更新攻击。
防御中毒攻击
就目前而言,少有先进方法能保护系统免受中毒攻击。区块链技术和数据完整性被提出,以获得更健壮的Federated Learning解决方案。
为了保证数据的私密性和安全性,在Federated Learning设置下,采用区块链模式建立了高效的数据访问控制,保证了大规模分布式数据的安全协作计算。考虑这样一种场景:一个客户端需要解决问题,一些客户端拥有适当的数据,而另一些客户端拥有具有足够计算资源的设备。对于这种场景,作者提出了一种加密方案,其中初始客户机创建公钥和私钥并加密模型参数。然后,适当的客户机协作利用所提供的资源和私有加密数据,以便成功地训练模型。
与数据中心训练相比,Federated Learning特别容易受到服务提供商无法控制的不可靠客户的非恶意故障的影响。与恶意攻击一样,系统因素和数据限制也加剧了数据中心设置中出现的非恶意故障。
接下来我们将讨论三种可能的非恶意故障模式:客户端报告故障,数据管道故障和嘈杂的模型更新。
客户报告失败
在Federated Learning中,每轮培训都涉及向客户端广播模型,本地客户端计算以及向中央聚合器的客户端报告。客户报告失败在跨设备Federated Learning中尤其容易发生,在这种情况下,网络带宽变得更加受约束,并且客户端设备更有可能是计算能力有限的边缘设备。
不幸的是,当使用安全聚合(SecAgg)时,无响应的客户端变得更具挑战性,尤其是如果客户端在SecAgg协议期间退出时。尽管SecAgg被设计为对大量丢失具有鲁棒性,但仍有失败的可能。此时可能需要提高SecAgg的效率或开发一种SecAgg的异步版本来改善。
数据管道故障
尽管Federated Learning中的数据管道仅存在于每个客户端中,但管道仍面临许多潜在问题。特别是,任何Federated Learning系统仍必须定义如何访问原始用户数据并将其预处理为训练数据,该管道中的错误或意外动作会极大地改变联合学习过程。尽管通常可以通过数据中心设置中的标准数据分析工具来发现数据管道错误,但Federated Learning中的数据限制使检测变得更加困难。
例如,服务器无法直接检测到功能级别的预处理问题(例如,像素倒置,连接词等)。一种可能的解决方案是使用具有差分隐私的联合方法训练生成模型,然后使用它们来合成可用于调试基础数据管道的新数据样本。开发不直接检查原始数据的通用机器学习调试方法仍然是一个挑战。
嘈杂的模型更新
除了恶意攻击之外,网络和体系结构因素,也会使得发送到服务器的模型更新变得失真。
即使客户端上的数据不是故意恶意的,它也可能具有嘈杂的功能(例如,在视觉应用程序中,客户端可能具有低分辨率的摄像头,其输出缩放到更高的分辨率)或嘈杂的标签(例如,如果用户无意中指出对某个应用程序的推荐不感兴趣)。由于这些损坏可以看作是模型更新和数据中毒攻击的温和形式,因此一种缓解策略是对防御模型更新和数据中毒攻击使用防御措施。另一种可能性是,标准的联合训练方法(例如联合平均 )固有地对少量噪声具有鲁棒性。