【Pytorch Geometric学习】 Message Passing 函数解析

1 前言

  此文为自己学习的总结,内容多有参考其他文章内容,并在参考文献处一并给出文章链接。

2 介绍

  Pytorch Geometric和dgl是两个基于Message Passing(消息传递)的图神经网络框架,两个框架的实现方式个人感觉有很大的不同。在此处介绍pyg的消息传递方式。

  图中的卷积计算通常被称为邻域聚合或者消息传递 (neighborhood aggregation or message passing). 定义 x i ( k − 1 ) ∈ R F \mathbf{x}_i ^{( k − 1 )} \in R^F xi(k1)RF为节点 i i i在第 ( k − 1 ) (k-1) (k1) 层的特征 e j , i \mathbf e_{j,i} ej,i 表示节点 j j j到 节点 i i i的边特征,在 GNN 中消息传递可以表示为
x i ( k ) = γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i ) ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) \mathbf{x}_{i}^{(k)}=\gamma^{(k)}\left(\mathbf{x}_{i}^{(k-1)}, \square_{j \in \mathcal{N}(i)} \phi^{(k)}\left(\mathbf{x}_{i}^{(k-1)}, \mathbf{x}_{j}^{(k-1)}, \mathbf{e}_{j, i}\right)\right) xi(k)=γ(k)(xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i))
其中 □ \square 表示具有置换不变性并且可微的函数,例如 sum, mean, max 等, γ \gamma γ ϕ \phi ϕ表示可微函数。
  在 PyTorch Gemetric 中,所有卷积算子都是由 MessagePassing 类派生而来,理解 MessagePasing 有助于我们理解 PyG 中消息传递的计算方式和编写自定义的卷积。在自定义卷积中,用户只需定义消息传递函数 ϕ \phi ϕ message(), 节点更新函数 γ \gamma γupdate() 以及聚合方式 aggr=‘add’, aggr=‘mean’ 或则 aggr=max. 具体函数说明如下:

下面具体介绍消息传递的三步曲:

2.1 Message passing消息传递

参考文献

[1]Pytorch-Geometric 中的 Message Passing 解析
[2]pytorch geometric教程一: 消息传递源码详解(MESSAGE PASSING)+实例
[3]SOURCE CODE FOR TORCH_GEOMETRIC.NN.CONV.MESSAGE_PASSING
[4]笔记:Pytorch-geometric: GAT代码超详细解读 | source node | target node | source_to_target
[5]PYG图卷积中的消息传递
[6]Pytorch中GNN的基类torch_geometric.nn.conv.MessagePassing
[7]https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html

你可能感兴趣的:(#,图机器学习,Pytorch,geometric)