PyG教程(6):自定义消息传递网络

一.前言

在上篇文章中主要介绍了GNN的消息传递机制,在PyG中提供了一个消息传递基类torch_geometric.nn.MessagePassing,它实现了消息传递的自动处理,继承该类就可以简单方便的构建自己的消息传播GNN。

本文的主要内容包括:MessagePassing类剖析、继承MessagePassing实现GAT。

二.如何自定义消息传递网络

要自定义GNN模型,首先需要继承MessagePassing类,然后重写如下方法:

  • message(...):构建要传递的消息;
  • aggregate(...):将从源节点传递过来的消息聚合到目标节点;
  • update(...):更新节点的消息。

上述方法并不是一定都要自定义,若MessagePassing类默认实现满足你的需求,则可以不重写。

2.1 构造函数

继承MessagePassing类后,在构造函数中可以通过super().__init__()方法来向基类MessagePassing传递参数,来指定消息传递的一些行为。MessagePassing类的初始化函数如下(该函数的参数便是可以通过子类向父类传递的参数):

def __init__(self, aggr: Optional[str] = "add",
             flow: str = "source_to_target", node_dim: int = -2,
             decomposed_layers: int = 1):

常用参数说明

参数 说明
aggr 消息传递聚合方式,常用的包括addmeanminmax等等。
flow 消息传播的方向,其中source_to_target表示从源节点到目标节点、target_to_source表示从目标节点到源节点
node_dim 传播的维度

2.2 propagate函数

在具体介绍消息传递的三个相关函数之前,首先先介绍propagate函数,该函数是消息传播的启动函数,调用该函数后以依次执行messageaggregateupdate方法来完成消息的传递聚合更新。该函数的声明如下所示:

propagate(self, edge_index: Adj, size: Size = None, **kwargs)

参数说明

参数 说明
edge_index 边索引
size 邻接矩阵的shape,若为None则表示方阵,若邻接矩阵不为方阵则需要通过其来传递邻接矩阵的shape
**kwargs 构建、聚合和更新消息所需的额外数据,都可以传入propagate函数,这些参数可以被消息传递的三个函数接收

该函数一般会传入edge_index和特征x

2.3 message函数

message函数是用来构建节点的消息的。传递给propagate函数的的tensor可以映射到中心节点和邻居节点上,只需要在相应变量名后加上_i_j即可,通常称_i中心节点,称_j邻居节点

示例:

self.propagate(edge_index, x=x):
    pass

def message(self, x_i, x_j, edge_index_i):
    pass

该例子中propagate函数接受两个参数edge_indexx,则message函数可以根据propagate函数中的两个参数构造自己的参数,上述message函数中构造的参数为:

  • x_i:中心节点的特征向量组成的矩阵,注意该矩阵与图节点的特征矩阵是不同的;
  • x_j:邻居节点的特征向量组成的矩阵;
  • edge_index_i:中心节点的索引。

注意,若flow='source_to_target',则消息将由邻居节点传向中心节点,若flow='target_to_source'则消息将从中心节点传向邻居节点,默认为第一种情况。

2.4 aggregate函数

消息聚合函数aggregate用来聚合来自邻居的消息,常用的包括addsummeanmax等,可以通过super().__init__()中的参数aggr来设定。该函数的第一个参数为message函数的输出(返回值)

2.5 update函数

update函数用来更新节点的消息,aggregate函数的输出(返回值)作为该函数的第一个参数

三.GAT实战

本节通过继承MessagePassing类来构建图注意力网络GAT。

3.1 GAT的消息传递机制

GAT的消息传递公式如下所示:
h i ( l + 1 ) = ∑ j ∈ N ( i ) α i , j W ( l ) h j ( l ) α i j l = s o f t m a x i ( e i j l ) e i j l = L e a k y R e L U ( a ⃗ T [ W h i ( l ) ∥ W h j ( l ) ] ) \begin{aligned} h_i^{(l+1)} & = \sum_{j\in \mathcal{N}(i)} \alpha_{i,j} W^{(l)} h_j^{(l)} \\ \alpha_{ij}^{l} & = \mathrm{softmax_i} (e_{ij}^{l})\\ e_{ij}^{l} & = \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i}^{(l)} \| W h_{j}^{(l)}]\right)\end{aligned} hi(l+1)αijleijl=jN(i)αi,jW(l)hj(l)=softmaxi(eijl)=LeakyReLU(a T[Whi(l)Whj(l)])
其中 h i ( l ) , h j ( l ) h_i^{(l)},h_j^{(l)} hi(l),hj(l)分别表示节点 i i i和节点 j j j在第 l l l层的特征向量。从上面的公式可以看出,在聚合邻居消息时,需要首先计算邻居节点到中心节点的注意力权重,然后进行加权求和。

3.2 具体实现

根据3.1节中的消息传递机制,GAT卷积层的实现如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax, add_remaining_self_loops


class GATConv(MessagePassing):
    def __init__(self, in_feats, out_feats, alpha, drop_prob=0.0):
        super().__init__(aggr="add")
        self.drop_prob = drop_prob
        self.lin = nn.Linear(in_feats, out_feats, bias=False)
        self.a = nn.Parameter(torch.zeros(size=(2*out_feats, 1)))
        self.leakrelu = nn.LeakyReLU(alpha)
        nn.init.xavier_uniform_(self.a)

    def forward(self, x, edge_index):
        edge_index, _ = add_remaining_self_loops(edge_index)
        # 计算 Wh
        h = self.lin(x)
        # 启动消息传播
        h_prime = self.propagate(edge_index, x=h)
        return h_prime

    def message(self, x_i, x_j, edge_index_i):
        # 计算a(Wh_i || wh_j)
        e = torch.matmul((torch.cat([x_i, x_j], dim=-1)), self.a)
        e = self.leakrelu(e)
        alpha = softmax(e, edge_index_i)
        alpha = F.dropout(alpha, self.drop_prob, self.training)
        return x_j * alpha


if __name__ == "__main__":
    conv = GATConv(in_feats=3, out_feats=3, alpha=0.2)
    x = torch.rand(4, 3)
    edge_index = torch.tensor(
        [[0, 1, 1, 2, 0, 2, 0, 3], [1, 0, 2, 1, 2, 0, 3, 0]], dtype=torch.long)
    x = conv(x, edge_index)
    print(x.shape)

在上述实现过程中,在message函数中(消息构建阶段)已经把注意力权重计算好了,因此后续聚合过程中只需要对邻居的权重求和即可,这可以通过 super().__init__(aggr="add")来实现。而消息更新也没有其它特别的操作,因此不需要自定义,按默认的即可。

通过上述的GAT卷积层,便可以构造一个GAT模型。

四.结语

参考资料:

  • CREATING MESSAGE PASSING NETWORKS
  • MessagePassing

PyG中实现自己的消息传递图神经网络是非常方便的,当然本文的介绍并不完善,后续可能会由额外的补充。

你可能感兴趣的:(图神经网络框架,GNN,人工智能,深度学习)