Pyg消息传递源码(MESSAGE PASSING)+实例

文章目录

  • 1. MessagePassing基类
  • 2. Message 源码
    • 2.1 MessagePassing初始化
    • 2.2 MessagePassing.propagate
    • 2.3 MessagePassing.message()
    • 2.4 MessagePassing.aggregate(inputs, index, …)
    • 2.5 MessagePassing.update(aggr_out, …)
    • 2.6 MessagePassing.message_and_aggregate(adj_t, …)
  • 3 实例

1. MessagePassing基类

GNN关键的步骤就是消息传递、聚合、更新。
pytorch geometric提供了一个MessagePassing基类,它已经通过MessagePassing.propagate()实现了以上三步对应的计算过程。
我们只需定义一个继承了MessagePassing基类的class,然后根据具体的图算法来更新函数 message() 的邻域聚合方式aggr=“add”, aggr=“mean” or aggr=“max”,以及函数update(),并在自定义的图算法卷积层中的forward函数里面调用progagate函数就可以。
大致流程如下:

import torch
from torch_geometric.nn import MessagePassing

class MyConv(MessagePassing): # 定义继承了MessagePassing基类的class
    def __init__(self, in_channels, out_channels, **kwargs):
        kwargs.setdefault('aggr', 'add')  # 邻域聚合方式
        super(MyConv, self).__init__(**kwargs)
        ...

    def forward(self, x, edge_index):
    	...
        return self.propagate(edge_index, **kwargs)

    def message(self, **kwargs):
    	...

2. Message 源码

2.1 MessagePassing初始化

def __init__(self, aggr: Optional[str] = "add",
           flow: str = "source_to_target", node_dim: int = -2,
           decomposed_layers: int = 1):

aggr: 邻域聚合方式,默认add,还可以是mean, max。
flow: 消息传递方向,默认从source_to_target,也可以设置为target_to_source。
node_dim: 定义沿着哪个维度进行消息传递,默认-2,因为-1是特征维度。

2.2 MessagePassing.propagate

 MessagePassing.propagate(edge_index, size=None, **kwargs)

progagate会依次调用messageaggregateupdate方法。如果edge_index是SparseTensor,会优先调用message_and_aggregate方法来代替messageaggregate方法。

edge_index:它有两种形式Tensor和SparseTensor。Tensor形式下的edge_index的shape是(2, N);SparseTensor则可以理解为稀疏矩阵的形式存储边信息。
size:当size为None的时候,默认邻接矩阵是方形[N, N]。如果是异构图(如bipartite图),图中的两类点的特征和index是相互独立的。通过传入size=(N, M),x=(x_N, x_M)时,propagate可以处理这种情况。
kwargs:图卷积计算过程的额外所需的信息,都可以通过kwargs传入。

2.3 MessagePassing.message()

这个方法在 flow=“source_to_target” 的设置下,计算了邻居节点 j 到中心节点 i 的消息。传给propagate()所有参数都可以传递给message(),而且传递给propagate()的tensors可以通过加上_i或_j的后缀来mapping到对应的节点

def message(self, x_j: Tensor) -> Tensor:
    return x_j

x_j:代表了邻居的特征,通过edge_index中邻居节点去索引对应位置的x得到

当edge_index的shape是(2, N_edges),x的shape是(N_nodes, N_features),则得到的x_j的shape是(N_edges, N_features)

例如:
edge_index:tensor([[1, 2, 3, 3], [0, 0, 0, 1]])
x:tensor([[0, 1], [2, 3], [4, 5], [6, 7]])
邻居节点 j 的index是edge_index的第一个元素[1,2,3,3],根据节点 j 的index [1, 2, 3, 3],去索引x对应的位置,则得到

x_j = x[index(j)]=x[[1,2,3,3]] = tensor([[2,3],[4,5],[6,7],[6,7]])

2.4 MessagePassing.aggregate(inputs, index, …)

这个方法实现了邻域的聚合,pytorch geometric通过scatter共实现了三种方式mean、 add、max。一般来说,比较通用的图算法,GCN、GraphSAGE、GAT都不需要自己再额外定义aggregate方法。

2.5 MessagePassing.update(aggr_out, …)

之前传入propagate的参数也都传入update。对应每个中心节点 i ,根据aggregate的邻域结果和传入propagate的参数中选择所需信息,更新节点 i 的embedding。

2.6 MessagePassing.message_and_aggregate(adj_t, …)

前面提到pytorch geometric中的边信息有Tensor和SparseTensor两种形式。
SparseTensor提供了矩阵存储形式,以稀疏矩阵方式存储,message_and_aggregate则提供了邻域聚合的矩阵计算方式〈不是所有的图卷积都可以用矩阵计算)。

当边是以SparseTensor存储的时候,propagate会优先去查找是否实现了message_and_aggregate如果已经实现了,就会调用message_and_aggregate来代替messageaggregate。如果没有实现, propagate需要将边信息转换为Tensor,然后再调用messageaggregate
message_and_aggregate是需要自己实现的,只有实现了它,才可以发挥矩阵计算的优势。

3 实例

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        self.lin = torch.nn.Linear(in_channels, out_channels)
    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        return self.propagate(edge_index, x=x, norm=norm)
    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='dataset/Cora', name='Cora')
data = dataset[0]
net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
print(h_nodes.shape)

你可能感兴趣的:(python项目,#学习记录,python,深度学习,机器学习)