本文翻译自Google AI Blog的Federated Learning: Collaborative Machine Learning without Centralized Training Data
标准机器学习方法需要将训练数据集中在一台机器或数据中心。Google已经构建了一个最安全,最强大的云基础架构来处理这些数据,从而使我们的服务更加完善。现在,对于通过用户与移动设备交互进行培训的模型,我们将引入另一种方法:联邦学习。
联邦学习使移动手机能够协同学习共享预测模型,同时将所有训练数据保存在设备上,从而将机器学习的能力与在云中存储数据的需求进行分离。这超出了通过将模型训练引入设备上而使用本地模型对移动设备进行预测的范围(如Mobile Vision API 和设备上的智能回复)。
它的工作原理如下:您的设备会下载当前模型,通过学习手机上的数据来改进它,然后将这些更改汇总为一个小型的重点更新。只有对模型的更新才会使用加密通信发送到云,并在其中立即与其他用户的更新进行平均,以改进共享模型。所有培训数据都保留在您的设备上,并且云中不会存储任何单独的更新。
您的手机会根据您的使用情况在本地对模型进行个性化设置(A)。聚合(B)许多用户的更新以形成对共享模型的共识改变(C),之后重复该过程。
联邦学习可以实现更智能的模型,更低的延迟和更低的功耗,同时确保隐私。这种方法还有另一个直接的好处:除了提供共享模型的更新之外,还可以立即使用手机上的改进模型,通过您使用手机的方式为个性化体验提供动力。
我们目前正在测试Android上的Gboard中的联合学习,Google键盘。当Gboard显示建议的查询时,您的手机会在本地存储有关当前上下文的信息以及您是否单击了该建议。联合学习处理设备上的历史记录,以建议改进Gboard的查询建议模型的下一次迭代。
为了使联邦学习成为可能,我们必须克服许多算法和技术挑战。在典型的机器学习系统中,像随机梯度下降(SGD)这样的优化算法运行在云中服务器之间均匀分区的大型数据集上。这种高度迭代的算法需要与训练数据的低延迟,高吞吐量连接。但在联邦学习环境中,数据以极不均匀的方式分布在数百万台设备上。此外,这些设备具有明显更高延迟,更低吞吐量的连接,并且只能间歇性地用于训练。
这些带宽和延迟限制激发了我们的联合平均算法,与普通联合版本的SGD相比,它可以使用10-100倍的通信来训练深度网络。关键的想法是使用现代移动设备中的强大处理器来计算比简单梯度步骤更高质量的更新。由于生成一个好的模型需要较少的高质量更新迭代,因此培训可以使用更少的通信。由于上传速度通常比下载速度慢得多,我们还开发了一种新方法,通过使用随机旋转和量化压缩更新,将上传通信成本降低到另外的100倍。虽然这些方法专注于训练深度网络,但我们也设计了算法对于高维稀疏凸模型,它在点击率预测等问题上表现优异。
将此技术部署到数百万运行Gboard的异构手机需要先进的技术堆栈。设备培训使用TensorFlow的微型版本。仔细调度可确保仅在设备空闲,插入和免费无线连接时才进行培训,因此不会影响手机的性能。
只有在不会对您的体验产生负面影响的情况下,您的手机才会参与联邦学习。然后,系统需要以安全,高效,可扩展和容错的方式通信和聚合模型更新。只有研究与这种基础设施的结合才能使联邦学习的好处成为可能。
联邦学习无需在云中存储用户数据,但我们并没有就此止步。我们开发了一种安全聚合协议使用加密技术,因此协调服务器只能在100或1000个用户参与时解密平均更新 - 在平均之前不能检查单个设备的更新。它是同类中第一个适用于深度网络规模问题和实际连接约束的协议。我们设计了Federated Averaging,因此协调服务器只需要平均更新,这允许使用Secure Aggregation; 但是协议是通用的,也可以应用于其他问题。我们正在努力实现此协议的生产实施,并期望在不久的将来将其部署到联邦学习应用程序中。
我们的工作只是触及了可能的表面。联邦学习无法解决所有机器学习问题(例如,通过对经过仔细标记的示例进行培训来学习识别不同的犬种,而对于许多其他模型,必要的培训数据已存储在云中(如Gmail的培训垃圾邮件过滤器)。因此,谷歌将继续推进基于云的ML的最新技术,但我们也致力于持续研究,以扩大我们可以通过联邦学习解决的问题范围。例如,除了Gboard查询建议之外,我们希望根据您在手机上实际输入的内容(可以拥有自己的风格)和基于人们看到的照片类型的照片排名来改进为键盘提供动力的语言模型,分享或删除。
应用联合学习需要机器学习从业者采用新的工具和新的思维方式:模型开发,培训和评估,不能直接访问或标记原始数据,而通信成本是一个限制因素。我们相信联邦学习的用户利益可以解决有价值的技术挑战,并且希望在机器学习社区内进行广泛的对话,从而发布我们的工作。