干货!用于时空数据预测的跨节点联邦学习图神经网络

点击蓝字

干货!用于时空数据预测的跨节点联邦学习图神经网络_第1张图片

关注我们

AI TIME欢迎每一位AI爱好者的加入!

传感器网络、可穿戴设备以及物联网设备产生的大量数据使能够利用分布式数据的时空数据建模方法显得更为重要,尤其是在边缘计算和数据访问控制的需求出现之后。近来兴起的联邦学习能够在避免直接在节点之间共享数据的条件下训练机器学习模型,但如何在联邦学习中有效利用不同节点之间的时空依赖关系仍然需要研究。另-方面,已有的时空数据预测模型忽视了不同节点之前数据共享的限制。我们在本工作中提出了一种跨节点的联邦学习图神经网络(CNFGNN) .CNFGNN能够在不直接访问节点数据的前提下利用图神经网络嵌入节点之间的图结构信息。CNFGNN解耦了时间和空间维度的建模并将二者分别限制在客户端和服务器端,同时使用交替优化的方法减少通信成本和客户端的计算成本。交通流数据预测任务上的实验显示,与已有的联邦学习方法相比,CNFGNN能够在不增加客户端计算成本的情况下取得最好的预测结果,同时有适中的通信成本。

论文代码和数据集已公开: 

https:/github.com/mengez13/KDD202CNFGNN

,带来分享《用于时空数据预测的跨节点联邦学习图神经网络》。

干货!用于时空数据预测的跨节点联邦学习图神经网络_第2张图片

孟垂正:南加州大学在读博士生,导师Yan Liu教授。主要研究兴趣包括基于物理学知识和联邦学习条件下的时空数据挖掘。

个人主页:https://mengcz13.github.io/。

01

 背  景 

使用分布式数据对时空数据进行建模的需求日渐兴起,在很多工作中,有边缘设备、传感器、物联网设备组成的网络收集的时空数据,如用户行为检测会利用个人移动设备上的数据等。在实际场景中,边缘设备常常被不同的组织或个人所拥有,从不同设备上手机数据会受到各种各样的限制或者没有访问权限。因此在实际应用中需要使用跨节点的时空数据,下图显示了跨节点分布的时空数据,红色圆圈代表数据的访问限制,表示每个服务器的数据仅仅可以被自己本身访问,不可被网络中其他节点或服务器进行访问。

干货!用于时空数据预测的跨节点联邦学习图神经网络_第3张图片

本文认为利用分布式数据和节点中潜在的关系对时空数据建模来讲至关重要,现有的工作利用图神经网络(GNN)建模节点之间复杂的依赖关系,但这些工作的模型必须使用从所有节点上集中起来的数据来进行训练,这违反了数据跨节点分布的限制。如下图,首先利用中心服务器从所有节点设备上收集数据,然后在中心服务器上训练一个基于GNN的模型给出预测结果。

干货!用于时空数据预测的跨节点联邦学习图神经网络_第4张图片

相比之下,基于联邦学习的方式能够高效地进行跨节点分布数据协同训练机器学习模型。下图显示了联邦学习的基本思路,边缘设备和服务器仅仅交换模型的参数,每个设备的数据和模型预测结果只允许该设备自身访问。遗憾的是,已有的联邦学习工作有的缺乏节点之间时空依赖关系考虑,如FedAvg (McMahan et al., 2017),FedProx (Li et al., 2020),Scaffold (Karimireddy et al., 2020)等方法;或者仅使用基于图结构的正则损失限制各节点上模型相似程度,如Federated Multi-Task Learning (FMTL) (Smith et al., 2017) ,但基于正则损失的方法对图结构的表达能力有限 (Kipf et al., 2017)。

干货!用于时空数据预测的跨节点联邦学习图神经网络_第5张图片

02

 方  法 

本文通过研究发现,在有效的时空建模方法和跨节点的分布式数据之间存在研究空白,有效的建模方法无法满足分布式数据限制,而适合分布式数据的方法又缺少对时空依赖关系的有效利用。因此本文提出跨节点的联邦学习图神经网络——CNFGNN。

干货!用于时空数据预测的跨节点联邦学习图神经网络_第6张图片

CNFGNN有以下三个特点:

i. 利用GNN捕捉多个节点的时空依赖关系,利用联邦学习满足分布式数据的限制。

ii. 解耦时间维度和空间维度的建模来联合GNN与联邦学习,使用交替优化的方法,在边缘设备使用本地数据进行时间维度建模,在中心服务器上进行空间维度建模。

iii. 在真实世界的交通流预测任务上,在不给边缘设备带来额外计算负担前提下,CNFGNN能够取得最优异的表现。

下图是CNFGNN的总体训练架构,每个节点都有一个使用本地数据提取时序信息的模型,服务器上的模型用来提取空间依赖信息。每个节点上的模型会将输入信息提取为嵌入向量,节点模型在预测时会使用全局嵌入信息为所在节点提供预测结果。服务器侧的模型从所有节点处收集嵌入向量,通过GNN在所有节点之间传播,捕捉节点之间的空间依赖关系。GNN输出即前面提到的全局嵌入信息。为了满足分布式数据要求,每个节点的输入特征、 标签和模型预测结果都只会暴露给节点本身。服务器只能获得两点信息,一是包含时序信息的输入序列的嵌入向量,二是网络的图结构,包含节点间的空间信息。

干货!用于时空数据预测的跨节点联邦学习图神经网络_第7张图片

在每轮训练过程中,执行以下步骤:

(1) 使用已有联邦学习方法,将所有节点模型训练若干轮数。

(2) 从所有节点处收集输入序列嵌入向量。

(3) 保持节点端模型不变,使用split learning的方法,将服务器端模型训练一定轮数。

(4) 服务器将更新的全局嵌入信息发送到节点,用于下一轮的节点模型训练。

下面是在节点端使用的模型——基于GRU的编码器-解码器结构,编码器部分用来提取输入序列的时序信息,解码器部分同时使用本地输入序列中提取的信息和全局嵌入信息来为当前节点进行预测。

干货!用于时空数据预测的跨节点联邦学习图神经网络_第8张图片

下图是服务器侧使用的模型——Graph Networks(GNs, Battaglia et al., 2018),GNs将节点的空间信息加入到全局嵌入信息中。

干货!用于时空数据预测的跨节点联邦学习图神经网络_第9张图片

Split learning (Singh et al., 2019) 用于训练分布在不同设备上的神经网络模块,但在跨节点的时空数据预测中使用split learning训练存在一些问题。

由于输入数据和标签都在节点上,嵌入信息位于服务器,因此在训练中对于每个节点,数据流会两次跨越设备边界,因此带来通信开销。如果仅使用本地数据训练节点上的模型会高度抑制,从本地输入中提取的嵌入向量分布在不同隐空间中,增加了服务器端训练难度。

为了解决上述问题,本文采取交替训练和FedAvg相结合的方法。

(1) 固定服务器端的模型以及嵌入向量,训练节点模型,使用FedAvg对齐节点模型编码器部分的输出。

(2) 将节点模型产生的总结向量上传到服务器,固定节点模型以及总结向量,使用split learning训练服务器端模型。

03

 实  验 

本文选择交通速度预测任务来评估CNFGNN的性能,这一任务中的数据跨传感器传播。对于下面对比的所有方法,使用相同的节点模型,即节点的预测模型都是使用GRU 编码器和解码器来预测。本文使用根均方误差(RMSE)来评价预测性能。

下图的实验结果中,centralized表示集中收集数据,local表示仅使用本地节点数据进行预测。GRU + FedAvg 表现最差,因为忽略了节点间的空间依赖关系。GRU+FMTL, CNFGNN表现较好,最重要的是考虑了节点间的依赖关系。CNFGNN模型表现最好,说明GNN相对比正则化损失方式更适合时空数据预测任务。

干货!用于时空数据预测的跨节点联邦学习图神经网络_第10张图片

下图比较了CNFGNN和表现最优的baseline——GRU+FMTL在边缘设备上的通信成本和计算成本。CNFGNN在保持计算成本基本不变和通信成本适中的情况下,很大程度提高了预测表现。

干货!用于时空数据预测的跨节点联邦学习图神经网络_第11张图片

04

 总  结 

本文提出了跨节点的联邦学习图神经网络(CNFGNN),CNFGNN通过在联邦学习中使用图神经网络(GNNs)来填补了复杂的时空数据建模和跨节点分布式数据之间的研究空白。其表现也是在计算成本基本不变和通信成本适中的前提下出色完成了任务。

今日视频推荐

整理:AI Timer

审核:孟垂正

AI TIME欢迎AI领域学者投稿,期待大家剖析学科历史发展和前沿技术。针对热门话题,我们将邀请专家一起论道。同时,我们也长期招募优质的撰稿人,顶级的平台需要顶级的你!

请将简历等信息发至[email protected]

微信联系:AITIME_HY

AI TIME是清华大学计算机系一群关注人工智能发展,并有思想情怀的青年学者们创办的圈子,旨在发扬科学思辨精神,邀请各界人士对人工智能理论、算法、场景、应用的本质问题进行探索,加强思想碰撞,打造一个知识分享的聚集地。

干货!用于时空数据预测的跨节点联邦学习图神经网络_第12张图片

更多资讯请扫码关注

干货!用于时空数据预测的跨节点联邦学习图神经网络_第13张图片

我知道你在看

干货!用于时空数据预测的跨节点联邦学习图神经网络_第14张图片

点击“阅读原文”查看精彩回放

你可能感兴趣的:(大数据,算法,python,机器学习,人工智能)