图中的卷积计算通常被称为邻域聚合或者消息传递 (neighborhood aggregation or message passing). 定义 x i ( k − 1 ) ∈ R F \mathbf x^{(k-1)}_i \in R^{F} xi(k−1)∈RF 为节点 i i i 在第 ( k − 1 ) (k-1) (k−1) 层的特征, e j , i \mathbf e_{j,i} ej,i 表示节点 j j j 到 节点 i i i 的边特征,在 GNN 中消息传递可以表示为
x i ( k ) = γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i ) ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) \mathbf x_{i}^{(k)} = \gamma^{(k)} \left(\mathbf x_{i}^{(k-1)}, \square_{j \in N(i)} \phi^{(k)} \left(\mathbf x_{i}^{(k-1)}, \mathbf x_{j}^{(k-1)}, \mathbf e_{j,i} \right) \right) xi(k)=γ(k)(xi(k−1),□j∈N(i)ϕ(k)(xi(k−1),xj(k−1),ej,i))
其中 □ \square □ 表示具有置换不变性并且可微的函数,例如 sum, mean, max 等, γ \gamma γ 和 ϕ \phi ϕ 表示可微函数。
在 PyTorch Gemetric 中,所有卷积算子都是由 MessagePassing
类派生而来,理解 MessagePasing
有助于我们理解 PyG 中消息传递的计算方式和编写自定义的卷积。在自定义卷积中,用户只需定义消息传递函数 ϕ \phi ϕ message()
, 节点更新函数 γ \gamma γ update()
以及聚合方式 aggr='add', aggr='mean'
或则 aggr=max
. 具体函数说明如下:
MessagePassing(aggr='add', flow='source_to_target', node_dim=-2)
定义聚合计算的方式 ('add', 'mean'
or max
) 以及消息的传递方向 (source_to_target
or target_to_source
). 在 PyG 中,中心节点为目标 target,邻域节点为源 source. node_dim
为消息聚合的维度MessagePassing.propagate(edge_index, size=None, **kwargs):
该函数接受边信息 edge_index
和其他额外的数据来执行消息传递并更新节点嵌入。MessagePassing.message(...):
该函数的作用是计算节点消息,就是公式中的函数 ϕ \phi ϕ . 如果 flow='source_to_target'
,那么消息将由邻域节点 j j j 传向中心节点 i i i ;如果 flow='target_to_source'
,消息则由中心节点 i i i 传向邻域节点 j j j . 传入参数的节点类型可以通过变量名后缀来确定,例如中心节点嵌入变量一般以 _i
为结尾,邻域节点嵌入变量以 x_j
为结尾MessagePassing.update(arr_out, ...):
该函数为节点嵌入的更新函数 γ \gamma γ , 输入参数为聚合函数 MessagePassing.aggregate
计算的结果为了更好的理解 PyG 中 MessagePassing
的计算过程,我们来分析一下源代码。
class MessagePassing(torch.nn.Module):
special_args: Set[str] = {
'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',
'size_i', 'size_j', 'ptr', 'index', 'dim_size'
}
def __init__(self, aggr: Optional[str] = "add",
flow: str = "source_to_target", node_dim: int = -2):
super(MessagePassing, self).__init__()
self.aggr = aggr
assert self.aggr in ['add', 'mean', 'max', None]
self.flow = flow
assert self.flow in ['source_to_target', 'target_to_source']
self.node_dim = node_dim
self.inspector = Inspector(self)
self.inspector.inspect(self.message)
self.inspector.inspect(self.aggregate, pop_first=True)
self.inspector.inspect(self.message_and_aggregate, pop_first=True)
self.inspector.inspect(self.update, pop_first=True)
self.__user_args__ = self.inspector.keys(
['message', 'aggregate', 'update']).difference(self.special_args)
self.__fused_user_args__ = self.inspector.keys(
['message_and_aggregate', 'update']).difference(self.special_args)
# Support for "fused" message passing.
self.fuse = self.inspector.implements('message_and_aggregate')
# Support for GNNExplainer.
self.__explain__ = False
self.__edge_mask__ = None
在初始化函数中,MessagePassing
定义了一个 Inspector
. Inspector 的中文意思是检查员的意思,这个类的作用就是检查各个函数的输入参数,并保存到 Inspector
的参数列表字典中 Inspector.params
中。如果 message
的输入参数为 x_i, x_j
,那么Inspector.params['message']={'x_i': Parameter, 'x_j': Parameter}
(注:这里仅作示意,实际 Inspector.params['message']
类型为 OrderedDict
). Inspector.implements
检查函数是否实现.
MessagePasing
中最核心的是 propgate
函数,假设邻接矩阵 edge_index
的类型为 Torch.LongTensor
,消息由 edge_index[0]
传向 edge_index[1]
,代码实现如下
def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
# 为了简化问题,这里不讨论 edge_index 为 SparseTensor 的情况,感兴趣的可阅读 PyG 原始代码
size = self.__check_input__(edge_index, size)
coll_dict = self.__collect__(self.__user_args__, edge_index, size,
kwargs)
msg_kwargs = self.inspector.distribute('message', coll_dict)
out = self.message(**msg_kwargs)
aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
out = self.aggregate(out, **aggr_kwargs)
update_kwargs = self.inspector.distribute('update', coll_dict)
return self.update(out, **update_kwargs)
在这段代码中,首先是检查节点数量和用户自定义的输入变量,然后依次执行 message
, aggregate
和 update
函数。如果是自定义图卷积,一般会重写 message
和 update
,这一点随后再以 GCN 为例解释,这里首先来看一下 aggregate
的实现
def aggregate(self, inputs: Tensor, index: Tensor,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None) -> Tensor:
if ptr is not None:
ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
return segment_csr(inputs, ptr, reduce=self.aggr)
else:
return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
reduce=self.aggr)
ptr
变量是针对邻接矩阵 edge_index
为 SparseTensor
的情况,此处暂且不论。inputs
为 message
计算得到的消息, index
就是待更新节点的索引,实际上就是 edge_index_i
. 聚合计算通过 scatter
函数实现。scatter
具体实现参考链接
下面以 GCN 为例,我们来看一下 MessagePassing
的计算过程。GCN 的计算公式如下
x i ( k ) = ∑ j ∈ N ( i ) ∪ { i } 1 deg ( i ) ⋅ deg ( j ) ⋅ ( Θ ⋅ x j ( k − 1 ) ) , \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right), xi(k)=j∈N(i)∪{i}∑deg(i)⋅deg(j)1⋅(Θ⋅xj(k−1)),
实际计算工程可以分为下面几步:
代码如下
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add') # "Add" aggregation (Step 5).
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: Start propagating messages.
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
# x_j has shape [E, out_channels]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j
在 forward
函数中,首先是给节点边增加自循环。设输入变量如下
edge_index = torch.tensor([[0, 0, 2], [1, 2, 3]], dtype=torch.long)
x = torch.rand((4, 3))
conv = GCNConv(3, 8)
注意到默认消息传递方向为 source_to_target
,此时edge_index[0]=x_j
为 source, edge_index[1]=x_i
为 target. 在 GCN 中,第一步是增加节点的自循环,add_self_loops
计算前后变化如下
# before add_self_loops
# edge_index=
tensor([[0, 0, 2],
[1, 2, 3]])
# after add_self_loops
# edge_index=
tensor([[0, 0, 2, 0, 1, 2, 3],
[1, 2, 3, 0, 1, 2, 3]])
# norm=
tensor([0.7071, 0.7071, 0.5000, 1.0000, 0.5000, 0.5000, 0.5000]
此处的 propagate
的输出参数由 edge_index, x, norm
, edge_index
是 propagete
必须输入的参数,x, norm
为用户自定义参数。在 __collect__
会根据变量名称来收集 message
需要的输入参数。在 GCN 中,norm
保持不变,x
将被映射到 x_j
,并且经过 __lift__
函数,其值也会发生变化。__lift__
函数如下
def __lift__(self, src, edge_index, dim):
if isinstance(edge_index, Tensor):
index = edge_index[dim]
return src.index_select(self.node_dim, index)
在本例中,输入的特征 shape=[4, 8]
,经过 __lift__
后,节点特征 shape=[7, 8]
. 经过 message
计算后,就可以执行 aggregate
和 update
了。