时空数据建模的跨节点联邦图神经网络:KDD21 Cross-Node Federated Graph Neural Network for Spatio-Temporal Data Modeling

前言

联邦学习(FL)虽然已经被广泛研究,但是对复杂的时空依赖关系进行建模以提高预测能力仍然是一个开放的问题。此外目前最优的时空预测模型假定对数据的访问不受限制,忽略了对数据共享的限制。基于此,本文提了一个基于图联邦学习的时空数据模型 Cross-Node Federated Graph Neural Network (CNFGNN),该模型在跨节点联邦学习的约束下,使用基于图神经网络(GNN)的架构对底层图结构进行编码,属于结构化联邦的一种,每个本地模型利用私有数据进行学习,并保持分散性。CNFGNN 通过分解客户端的时间动态过程和服务器的空间动态过程完成建模,利用交替优化降低通信成本,实现结构联邦中客户端的协同训练。

介绍

联邦学习(FL)实现基于多个本地客户都安的分散数据来训练模型,但是没有考虑固有的时空依赖性,或通过在模型权重的正则化中强加图结构来隐含建模。后者受到基于正则化方法和归纳式学习的限制。Cross-Node Federated Graph Neural Network (CNFGNN) 旨在跨节点联邦学习约束下有效地建立复杂的时空依赖关系。为此,CNFGNN 对时间和空间依赖关系的建模进行分解,在每个本地客户端上使用 encoder-decoder 模型来提取本地数据的时间特征,在服务器上使用基于图神经网络(GNN)的模型来捕捉本地客户端间的空间依赖关系。

与现有的依靠正则化控制本地客户端间关系的联邦学习方法相比,CNFGNN 利用基于 GNN 的显式图结构,从而带来性能的提升。然而受限于数据共享的约束,GNN 不能以数据集中的方式进行训练。基于此,CNFGNN 采用分割学习来训练空间和时间模块。为平衡通信成本,本文提出了一个基于交替优化的训练方法(alternating optimization-based procedure),与一般的分布式学习框架相比,只产生了一半的通信开销并且本文使用 FedAvg 训练所有节点的共享时间特征提取器。一般的捕捉数据之间关系的多任务学习框架。虽然在一定程度上缓解了邻域信息缺失的问题,但不像 GNN 模型那样有效,仍然存在缺乏特征交换和聚合的问题。

由于不同用户/组织所拥有的不同客户端上收集的数据可能由于边缘计算的需要或数据访问的许可问题而不允许共享,因此有必要设计一种对时空关系进行建模的算法,而无需直接交换节点级数据.

主要贡献如下:

(1)我们提出了跨节点联邦图神经网络(CNFGNN),这是一种基于GNN的联邦学习体系结构,可以捕获多个节点之间复杂的时空关系,同时确保在边缘设备上以不额外计算成本的方式保持局部生成的数据分散。

(2) 我们的建模和培训过程使基于GNN的体系结构能够在联邦学习环境中使用。我们通过分离边缘设备上的局部时间动态建模和中央服务器上的空间动态建模来实现这一点,并利用基于交替优化的过程,使用分割学习和联邦平均来更新空间和时间模块,以实现有效的基于GNN的联邦学习。

(3) 我们证明,与交通流预测任务中的相关技术相比,CNFGNN在边缘设备上以适度的通信成本在不增加额外计算成本的情况下实现了最佳的预测性能(在传输和感应设置中)。

方法

首先将节点级的时间动态关系和服务器级的空间动态关系的建模分解如下:

  1. (图 1 c)在每个节点上,一个 encoder-decoder 模型从节点的数据中提取时间特征并进行预测;
  2. (图 1 b)在中央服务器上,一个图网络(GN)传播提取的节点时间特征并输出节点嵌入,其中包括节点之间的关系信息。

步骤 1 基于本地私有数据在每个节点上本地执行。步骤 2 只涉及上传的特征和梯度。这种分解能够在跨节点联邦学习的约束下实现节点信息的交换和聚合。

时空数据建模的跨节点联邦图神经网络:KDD21 Cross-Node Federated Graph Neural Network for Spatio-Temporal Data Modeling_第1张图片

时空数据建模的跨节点联邦图神经网络:KDD21 Cross-Node Federated Graph Neural Network for Spatio-Temporal Data Modeling_第2张图片

 节点内部的计算操作

使用 GRU 机制构建 encoder-decoder 来建模 node-level 的时间动态信息,给定一个输入序列  ,表示第 i个节点在过去  个时间步中维度为  的特征信息。编码器依次读取整个序列,输出  作为输入序列的隐层状态:(其中  为初始化为 0 的隐层状态向量)。

为了将空间动态信息纳入模型生成每个节点的预测,将  与后面生成的包含空间信息的节点嵌入  连接起来作为解码器的初始状态向量。解码器从输入序列的最后一帧  开始,以自动回归的方式生成预测  ,并将隐藏状态向量串联起来:

本文选择预测值和真实值之间的均方误差(MSE)作为损失函数,对每个节点进行局部评估。

中央服务器计算操作

为了捕捉复杂的空间动态信息,本文采用图网络(GNs)来生成包含所有节点的关系信息的节点嵌入。中央服务器从所有节点  收集隐藏状态作为 GN 的输入。GN 的每一层都对输入特征进行更新,具体如下:

 时空数据建模的跨节点联邦图神经网络:KDD21 Cross-Node Federated Graph Neural Network for Spatio-Temporal Data Modeling_第3张图片

其中  分别代表边特征,节点特征和全局特征。代表神经网络。 代表聚合函数(例如 Sum 算子)。其实上述过程就是基于消息传递的GNN,分为特征传播、信息聚合和特征更新三个过程。本文选用两层带有残差链接的 GN 作为所有实验的 GNN 网络模型。并设置( W 代表邻接矩阵)。将空矢量分配给  作为第一层 GN 的输入。服务器端 GN 输出所有节点的嵌入,并将每个节点的嵌入相应地发送给本地客户端

 节点级和空间模型的交替训练。

 跨节点联邦学习的缺点在于训练阶段的高通信成本。由于将模型的不同部分分布在不同的设备上,分割学习是一个潜在可行的训练框架,其中隐向量和梯度在设备之间进行通信。然而通过分割学习简单地训练模型时,中央服务器需要从所有节点接收隐藏状态,并在前向传播中向所有节点发送节点嵌入,然后必须从所有本地客户端节点接收节点嵌入的梯度,并在反向传播中向所有节点发送隐藏状态的梯度。假设所有的隐藏状态和节点嵌入具有相同的大小 S ,GN 模型每轮训练中传输的数据总量为  。

为了减轻训练阶段的高通信成本,CNFGNN 改为在节点上交替训练模型,在服务器上训练 GN 模型。在每一轮训练中:

  1. 固定节点嵌入 并优化 Rc 轮的 encoder-decoder 模型;
  2. 在固定节点上的所有模型的同时优化 GN 模型。

 由于节点上的模型是固定的,  在 GN 模型训练期间保持不变,服务器只需要在 GN 训练开始前从节点上获取  ,并且只需要传输节点嵌入和梯度。因此,每轮 GN 模型训练的平均数据量减少到

 为了更有效地从每个节点中提取时间特征,使用 FedAvg 算法对节点进行 encoder-decoder 模型的训练。使得所有节点共享相同的特征提取器,从而共享一个联邦时间特征隐藏空间,避免了节点上模型的潜在过拟合,并通过经验证明了更快的收敛和更好的预测性能。

 

时空数据建模的跨节点联邦图神经网络:KDD21 Cross-Node Federated Graph Neural Network for Spatio-Temporal Data Modeling_第4张图片

 

时空数据建模的跨节点联邦图神经网络:KDD21 Cross-Node Federated Graph Neural Network for Spatio-Temporal Data Modeling_第5张图片

 试验

 时空数据建模:交通流预测

时空数据建模的跨节点联邦图神经网络:KDD21 Cross-Node Federated Graph Neural Network for Spatio-Temporal Data Modeling_第6张图片

Baselines(可见模型超参数)

  • GRU(centralized)
  • GRU + GN (centralized)
  • GRU (local)
  • GRU + FedAvg
  • GRU+ FMTL

 结果

 

时空数据建模的跨节点联邦图神经网络:KDD21 Cross-Node Federated Graph Neural Network for Spatio-Temporal Data Modeling_第7张图片

时空数据建模的跨节点联邦图神经网络:KDD21 Cross-Node Federated Graph Neural Network for Spatio-Temporal Data Modeling_第8张图片

 

时空数据建模的跨节点联邦图神经网络:KDD21 Cross-Node Federated Graph Neural Network for Spatio-Temporal Data Modeling_第9张图片

 

 

 总结

 我们提出了跨节点联邦图神经网络(CNFGNN),它通过在联邦学习环境中使用图神经网络(GNN)来弥补复杂时空数据建模分散数据处理之间的差距。我们通过使用基于分裂学习联邦平均的时空模块交替优化来解耦本地时间模型和服务器端空间模型的学习来实现这一点。我们在两个真实数据集上的交通流预测实验结果表明,与竞争技术相比,该方法具有更高的性能。我们未来的工作包括应用现有的GNN模型和采样策略,并将其集成到大规模图的CNFGNN中,将CNFGNN扩展到一个完全分散的框架,并将现有的用于图学习的隐私保护方法合并到CNFGNN中,以增强时空动态的联合学习。

 

 

你可能感兴趣的:(神经网络,人工智能,深度学习,机器学习)