在上篇文章中主要介绍了GNN的消息传递机制,在PyG中提供了一个消息传递基类torch_geometric.nn.MessagePassing
,它实现了消息传递的自动处理,继承该类就可以简单方便的构建自己的消息传播GNN。
本文的主要内容包括:MessagePassing
类剖析、继承MessagePassing
实现GAT。
要自定义GNN模型,首先需要继承MessagePassing
类,然后重写如下方法:
message(...)
:构建要传递的消息;aggregate(...)
:将从源节点传递过来的消息聚合到目标节点;update(...)
:更新节点的消息。上述方法并不是一定都要自定义,若
MessagePassing
类默认实现满足你的需求,则可以不重写。
继承MessagePassing
类后,在构造函数中可以通过super().__init__()
方法来向基类MessagePassing
传递参数,来指定消息传递的一些行为。MessagePassing
类的初始化函数如下(该函数的参数便是可以通过子类向父类传递的参数):
def __init__(self, aggr: Optional[str] = "add",
flow: str = "source_to_target", node_dim: int = -2,
decomposed_layers: int = 1):
常用参数说明:
参数 | 说明 |
---|---|
aggr |
消息传递聚合方式,常用的包括add 、mean 、min 、max 等等。 |
flow |
消息传播的方向,其中source_to_target 表示从源节点到目标节点、target_to_source 表示从目标节点到源节点 |
node_dim |
传播的维度 |
propagate
函数在具体介绍消息传递的三个相关函数之前,首先先介绍propagate
函数,该函数是消息传播的启动函数,调用该函数后以依次执行message
、aggregate
和update
方法来完成消息的传递、聚合和更新。该函数的声明如下所示:
propagate(self, edge_index: Adj, size: Size = None, **kwargs)
参数说明:
参数 | 说明 |
---|---|
edge_index |
边索引 |
size |
邻接矩阵的shape,若为None 则表示方阵,若邻接矩阵不为方阵则需要通过其来传递邻接矩阵的shape |
**kwargs |
构建、聚合和更新消息所需的额外数据,都可以传入propagate 函数,这些参数可以被消息传递的三个函数接收。 |
该函数一般会传入
edge_index
和特征x
。
message
函数message
函数是用来构建节点的消息的。传递给propagate
函数的的tensor
可以映射到中心节点和邻居节点上,只需要在相应变量名后加上_i
或_j
即可,通常称_i
为中心节点,称_j
为邻居节点。
示例:
self.propagate(edge_index, x=x):
pass
def message(self, x_i, x_j, edge_index_i):
pass
该例子中propagate
函数接受两个参数edge_index
和x
,则message
函数可以根据propagate
函数中的两个参数构造自己的参数,上述message
函数中构造的参数为:
x_i
:中心节点的特征向量组成的矩阵,注意该矩阵与图节点的特征矩阵是不同的;x_j
:邻居节点的特征向量组成的矩阵;edge_index_i
:中心节点的索引。注意,若
flow='source_to_target'
,则消息将由邻居节点传向中心节点,若flow='target_to_source'
则消息将从中心节点传向邻居节点,默认为第一种情况。
aggregate
函数消息聚合函数aggregate
用来聚合来自邻居的消息,常用的包括add
、sum
、mean
和max
等,可以通过super().__init__()
中的参数aggr
来设定。该函数的第一个参数为message
函数的输出(返回值)。
update
函数update
函数用来更新节点的消息,aggregate
函数的输出(返回值)作为该函数的第一个参数。
本节通过继承MessagePassing
类来构建图注意力网络GAT。
GAT的消息传递公式如下所示:
h i ( l + 1 ) = ∑ j ∈ N ( i ) α i , j W ( l ) h j ( l ) α i j l = s o f t m a x i ( e i j l ) e i j l = L e a k y R e L U ( a ⃗ T [ W h i ( l ) ∥ W h j ( l ) ] ) \begin{aligned} h_i^{(l+1)} & = \sum_{j\in \mathcal{N}(i)} \alpha_{i,j} W^{(l)} h_j^{(l)} \\ \alpha_{ij}^{l} & = \mathrm{softmax_i} (e_{ij}^{l})\\ e_{ij}^{l} & = \mathrm{LeakyReLU}\left(\vec{a}^T [W h_{i}^{(l)} \| W h_{j}^{(l)}]\right)\end{aligned} hi(l+1)αijleijl=j∈N(i)∑αi,jW(l)hj(l)=softmaxi(eijl)=LeakyReLU(aT[Whi(l)∥Whj(l)])
其中 h i ( l ) , h j ( l ) h_i^{(l)},h_j^{(l)} hi(l),hj(l)分别表示节点 i i i和节点 j j j在第 l l l层的特征向量。从上面的公式可以看出,在聚合邻居消息时,需要首先计算邻居节点到中心节点的注意力权重,然后进行加权求和。
根据3.1节中的消息传递机制,GAT卷积层的实现如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax, add_remaining_self_loops
class GATConv(MessagePassing):
def __init__(self, in_feats, out_feats, alpha, drop_prob=0.0):
super().__init__(aggr="add")
self.drop_prob = drop_prob
self.lin = nn.Linear(in_feats, out_feats, bias=False)
self.a = nn.Parameter(torch.zeros(size=(2*out_feats, 1)))
self.leakrelu = nn.LeakyReLU(alpha)
nn.init.xavier_uniform_(self.a)
def forward(self, x, edge_index):
edge_index, _ = add_remaining_self_loops(edge_index)
# 计算 Wh
h = self.lin(x)
# 启动消息传播
h_prime = self.propagate(edge_index, x=h)
return h_prime
def message(self, x_i, x_j, edge_index_i):
# 计算a(Wh_i || wh_j)
e = torch.matmul((torch.cat([x_i, x_j], dim=-1)), self.a)
e = self.leakrelu(e)
alpha = softmax(e, edge_index_i)
alpha = F.dropout(alpha, self.drop_prob, self.training)
return x_j * alpha
if __name__ == "__main__":
conv = GATConv(in_feats=3, out_feats=3, alpha=0.2)
x = torch.rand(4, 3)
edge_index = torch.tensor(
[[0, 1, 1, 2, 0, 2, 0, 3], [1, 0, 2, 1, 2, 0, 3, 0]], dtype=torch.long)
x = conv(x, edge_index)
print(x.shape)
在上述实现过程中,在message
函数中(消息构建阶段)已经把注意力权重计算好了,因此后续聚合过程中只需要对邻居的权重求和即可,这可以通过 super().__init__(aggr="add")
来实现。而消息更新也没有其它特别的操作,因此不需要自定义,按默认的即可。
通过上述的GAT卷积层,便可以构造一个GAT模型。
参考资料:
PyG中实现自己的消息传递图神经网络是非常方便的,当然本文的介绍并不完善,后续可能会由额外的补充。