图神经网络通用框架 —— MPNN消息传递神经网络

前言

大家好,我是阿光。

本专栏整理了《图神经网络》,内包含了不同图神经网络的原理以及相关代码实现,详细讲解图神经网络,理论与实践相结合,如GCN、GraphSAGE、GAT等经典图网络,每一个代码实例都附带有完整的代码+数据集。

正在更新中~ ✨

我的项目环境:

  • 平台:Windows10
  • 语言环境:python3.7
  • 编译器:PyCharm
  • PyTorch版本:1.11.0
  • PyG版本:2.1.0

项目专栏:【入门图神经网络】


一、图计算框架

对于空域的图神经网络常见有GAT、GCN等变体,除此之外目前还提出了几个通用的计算框架,目的就是将一些模型集成到一个框架当中。

对于常见的计算模块通常有传播模块、采样模块、池化模块等,针对每个模块目前都有很多不同的设计模式,所以一些研究人员想要提出一个通用的框架来表示这个计算流程。

二、计算模块

  • 传播模块:用于在节点之间传播信息,以便聚合的信息可以捕获特征和拓扑信息。在传播模块中,卷积算子和递归算子通常用于聚集来自邻居的信息,而跳过连接操作用于从节点的历史表示中收集信息并缓解过度平滑问题。

  • 采样模块:当图很大时,通常需要采样模块来对图进行传播。采样模块通常与传播模块相结合。

  • 池化模块:当我们需要高级子图或图的表示时,需要池模块来从节点中提取信息。

图神经网络通用框架 —— MPNN消息传递神经网络_第1张图片
Monti等人提出了混合模型网络MoNet,这是在图或流形上定义的几种方法的通用空间框架。Gelmer等人提出了消息传递神经网络MPNN,该网络使用消息传递函数来统一几个变体。Wang等人提出了非局部神经网络NLNN,它统一了几种“自我注意力”方法。Battaglia等人提出了图形网络GN,它是一个更加通用的框架,能够学习节点级、边级和图级别的表示学习。

三、MPNN消息传递神经网络

MPNN是Gilmer等人为实现分子性质预测提出的一个通用计算框架。

该模型主要包含两个阶段:消息传递阶段和读取阶段

  • 消息传递阶段(Message Passing Phase):该阶段分别两个步骤,分别是聚合邻居信息和更新状态信息
    m v t + 1 = ∑ u ∈ N v M t ( h v t , h u t , e v u ) m_v^{t+1}=\sum_{u\in N_v}M_t(h_v^t,h_u^t,e_{vu}) mvt+1=uNvMt(hvt,hut,evu)

上式用于聚合邻居,其中 M t M_t Mt 代表一个可微函数,用于聚合信息使用,最简单的设计就是一个全连接神经网络,将输入的节点与邻居节点的特征向量以及边特征进行映射形成新的维度特征,然后将所有邻居聚合后的信息进行加和。

h v t + 1 = U t ( h v t , m v t + 1 ) h_v^{t+1}=U_t(h_v^t,m_v^{t+1}) hvt+1=Ut(hvt,mvt+1)

该式是用于更新节点状态的,同理 U t U_t Ut 也是一个可微函数,输入的信息是当前节点的隐状态信息以及聚合后的信息,经过更新作为新的时间点的隐状态特征向量。

对于下标v表示当前节点,上标t代表第t个时间步,或者可以理解为第t层特征经过了几个网络层, N v N_v Nv代表节点v的所有邻居。

  • 读取阶段(Readout Phase):这个阶段主要是为整张图计算一个特征向量用于表示整张图

y ^ = R ( h v T ∣ v ∈ G ) \hat y=R({h_v^T|v\in G}) y^=R(hvTvG)

对于R一般常见有均值聚合器、最大聚合器、LSTM聚合器,类似于图像处理中的最大池化层,对于均值聚合器,我们可以将整张图所有节点的特征向量进行加权平均作为整张图的向量表示,对于LSTM聚合器就是将所有节点的特征向量输入到LSTM当中,利用LSTM网络进行提取信息然后将最后一个时间步的输出作为整张图的表示。

参考文章

  • Neural Message Passing for Quantum Chemistry
  • Graph neural networks: A review of methods and applications

你可能感兴趣的:(图神经网络,神经网络,深度学习,python)