GNN目前主流的做法是递归迭代聚合一阶邻域表征来更新节点表征,如[[GCN]]和 [[GraphSAGE]],但这些方法大多是经验主义,缺乏理论去理解GNN到底做了什么,还有什么改进空间。
How Powerful are Graph Neural Networks? 本文基于Weisfeiler-Lehman(WL) test 视角理论分析了GNN,包括:
GNN的目标是以图结构数据和节点特征作为输入,以学习到节点(或图)的embedding,用于分类任务。
基于邻域聚合的GNN可以拆分为以下三个模块:
**1.Aggregate:**聚合一阶邻域特征。
**2.Combine:**将邻居聚合的特征 与 当前节点特征合并, 以更新当前节点特征。
**3. Readout(可选):**如果是对graph分类,需要将graph中所有节点特征转变成graph特征。
Weisfeiler-Lehman(WL) test 是判断两个graph 是否具有相同的结构(同构)非常有效的方法,迭代进行以下操作得到节点新标签以判断同构性:
(初始化:将节点自身id作为标签)
聚合方案:聚合每个节点邻域和自身标签。
更新节点标签:使用Hash映射节点聚合标签,作为节点新标签。
**GNN迭代过程和WL test非常相似,**受这个启发,作者提出定理2,证明了 WL test是GNN能力的上限:如果任意G1, G2 是非同构图,如果存在GNN可以将其映射到两个不同的embedding,那么WL test 也能判断G1、G2非同构。
此方法来自于Weisfeiler-Lehman Graph Kernels。
WL Test 算法的一点局限性是,它只能判断两个图的相似性,无法衡量图之间的相似性。要衡量两个图的相似性,我们用WL Subtree Kernel方法。该方法的思想是用WL Test算法得到节点的多层的标签,然后我们可以分别统计图中各类标签出现的次数,存于一个向量,这个向量可以作为图的表征。两个图的这样的向量的内积,即可作为这两个图的相似性的估计。
作者提出定理3:如果GNN中Aggregate、Combine 和 Readout函数是单射,GNN可以和WL test一样强大
能实现判断图同构性的图神经网络需要满足,只在两个节点自身标签一样且它们的邻接节点一样时,图神经网络将这两个节点映射到相同的表征,即映射是单射性的。可重复集合(Multisets)指的是元素可重复的集合,元素在集合中没有顺序关系。 **一个节点的所有邻接节点是一个可重复集合,一个节点可以有重复的邻接节点,邻接节点没有顺序关系。**因此GIN模型中生成节点表征的方法遵循WL Test算法更新节点标签的过程。
在生成节点的表征后仍需要执行图池化(或称为图读出)操作得到图表征,最简单的图读出操作是做求和。由于每一层的节点表征都可能是重要的,因此在图同构网络中,不同层的节点表征在求和后被拼接,其数学定义如下,
h G = CONCAT ( READOUT ( { h v ( k ) ∣ v ∈ G } ) ∣ k = 0 , 1 , ⋯ , K ) h_{G} = \text{CONCAT}(\text{READOUT}\left(\{h_{v}^{(k)}|v\in G\}\right)|k=0,1,\cdots, K) hG=CONCAT(READOUT({hv(k)∣v∈G})∣k=0,1,⋯,K)
**采用拼接而不是相加的原因在于不同层节点的表征属于不同的特征空间。**未做严格的证明,这样得到的图的表示与WL Subtree Kernel得到的图的表征是等价的。
分析基于mean、max的aggregator的特性和不足。
节点v和v’为中心节点,通过聚合邻居特征生成embedding,分析不同aggregate设置下是否能区分不同的结构(如果能捕获不同结构,二者的embedding应该不一样)。
结论:由于mean和max-pooling 函数 不满足单射性,无法区分某些结构的图,故性能会比sum差一点。
sum可以学习精确的结构信息,mean偏向学习分布信息,max偏向学习有代表性的元素信息。
基于图同构网络的图表征学习主要包含以下两个过程:
在这里,我们将采用自顶向下的方式,来学习基于图同构模型(GIN)的图表征学习方法。我们首先关注如何基于节点表征计算得到图的表征,而忽略计算结点表征的方法。
此模块首先采用GINNodeEmbedding
模块对图上每一个节点做节点嵌入(Node Embedding),得到节点表征,然后对节点表征做图池化得到图的表征,最后用一层线性变换得到图的表征(graph representation)。代码实现如下:
import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from gin_node import GINNodeEmbedding
class GINGraphRepr(nn.Module):
def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):
"""GIN Graph Pooling Module
Args:
num_tasks (int, optional): number of labels to be predicted. Defaults to 1 (控制了图表征的维度,dimension of graph representation).
num_layers (int, optional): number of GINConv layers. Defaults to 5.
emb_dim (int, optional): dimension of node embedding. Defaults to 300.
residual (bool, optional): adding residual connection or not. Defaults to False.
drop_ratio (float, optional): dropout rate. Defaults to 0.
JK (str, optional): 可选的值为"last"和"sum"。选"last",只取最后一层的结点的嵌入,选"sum"对各层的结点的嵌入求和。Defaults to "last".
graph_pooling (str, optional): pooling method of node embedding. 可选的值为"sum","mean","max","attention"和"set2set"。 Defaults to "sum".
Out:
graph representation
"""
super(GINGraphPooling, self).__init__()
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
self.emb_dim = emb_dim
self.num_tasks = num_tasks
if self.num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)
# Pooling function to generate whole-graph embeddings
if graph_pooling == "sum":
self.pool = global_add_pool
elif graph_pooling == "mean":
self.pool = global_mean_pool
elif graph_pooling == "max":
self.pool = global_max_pool
elif graph_pooling == "attention":
self.pool = GlobalAttention(gate_nn=nn.Sequential(
nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
elif graph_pooling == "set2set":
self.pool = Set2Set(emb_dim, processing_steps=2)
else:
raise ValueError("Invalid graph pooling type.")
if graph_pooling == "set2set":
self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
else:
self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
def forward(self, batched_data):
h_node = self.gnn_node(batched_data)
h_graph = self.pool(h_node, batched_data.batch)
output = self.graph_pred_linear(h_graph)
if self.training:
return output
else:
# At inference time, relu is applied to output to ensure positivity
# 因为预测目标的取值范围就在 (0, 50] 内
return torch.clamp(output, min=0, max=50)
可以看到可选的基于结点表征计算得到图表征的方法有:
torch_geometric.nn.glob.global_add_pool
。torch_geometric.nn.glob.global_mean_pool
。torch_geometric.nn.glob.global_max_pool
。PyG中集成的所有的图池化的方法可见于Global Pooling Layers。
接下来我们将学习节点嵌入的方法。
此模块基于多层GINConv
实现结点嵌入的计算。此处我们先忽略GINConv
的实现。此模块得到的节点属性输入为类别型向量,我们首先用AtomEncoder
对其做嵌入得到第0
层节点表征(稍后我们再对AtomEncoder
做分析)。然后我们逐层计算节点表征,从第1
层开始到第num_layers
层,每一层节点表征的计算都以上一层的节点表征h_list[layer]
、边edge_index
和边的属性edge_attr
为输入。需要注意的是,GINConv
的层数越多,此模块的感受野(receptive field)越大,结点i
的表征最远能捕获到结点i
的距离为num_layers
的邻接节点的信息。
import torch
from mol_encoder import AtomEncoder
from gin_conv import GINConv
import torch.nn.functional as F
# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):
"""
Output:
node representations
"""
def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
"""GIN Node Embedding Module"""
super(GINNodeEmbedding, self).__init__()
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
# add residual connection or not
self.residual = residual
if self.num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.atom_encoder = AtomEncoder(emb_dim)
# List of GNNs
self.convs = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
for layer in range(num_layers):
self.convs.append(GINConv(emb_dim))
self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
def forward(self, batched_data):
x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr
# computing input node embedding
h_list = [self.atom_encoder(x)] # 先将类别型原子属性转化为原子表征
for layer in range(self.num_layers):
h = self.convs[layer](h_list[layer], edge_index, edge_attr)
h = self.batch_norms[layer](h)
if layer == self.num_layers - 1:
# remove relu for the last layer
h = F.dropout(h, self.drop_ratio, training=self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
if self.residual:
h += h_list[layer]
h_list.append(h)
# Different implementations of Jk-concat
if self.JK == "last":
node_representation = h_list[-1]
elif self.JK == "sum":
node_representation = 0
for layer in range(self.num_layers + 1):
node_representation += h_list[layer]
return node_representation
接下来我们来学习图同构网络的关键组件GINConv
。
GINConv
–图同构卷积层图同构卷积层的数学定义如下:
x i ′ = h Θ ( ( 1 + ϵ ) ⋅ x i + ∑ j ∈ N ( i ) x j ) \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right) xi′=hΘ⎝⎛(1+ϵ)⋅xi+j∈N(i)∑xj⎠⎞
PyG中已经实现了此模块,我们可以通过torch_geometric.nn.GINConv
来使用PyG定义好的图同构卷积层,然而该实现不支持存在边属性的图。在这里我们自己自定义一个支持边属性的GINConv
模块。
由于输入的边属性为类别型,因此我们需要先将类别型边属性转换为边表征。我们定义的GINConv
模块遵循“消息传递、消息聚合、消息更新”这一过程。
self.propagate
的调用开始执行,该函数接收edge_index
, x
, edge_attr
此三个函数。edge_index
是形状为2,num_edges
的张量(tensor)。x_i
和x_j
张量,x_j
表示了消息传递的源节点,x_i
表示了消息传递的目标节点。message
函数被调用,此函数定义了从源节点传入到目标节点的消息,在这里要传递的消息是源节点表征与边表征之和的relu
。我们在super(GINConv, self).__init__(aggr = "add")
中定义了消息聚合方式为add
,那么传入给任一个目标节点的所有消息被求和得到aggr_out
,它是目标节点的中间过程的信息。GINConv
继承了MessagePassing
类,因此update
函数被调用。然而我们希望对节点做消息更新中加入目标节点自身的消息,因此在update
函数中我们只简单返回输入的aggr_out
。forward
函数中我们执行out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
实现消息的更新。import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder
### GIN convolution along the graph structure
class GINConv(MessagePassing):
def __init__(self, emb_dim):
'''
emb_dim (int): node embedding dimensionality
'''
super(GINConv, self).__init__(aggr = "add")
self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))
self.eps = nn.Parameter(torch.Tensor([0]))
self.bond_encoder = BondEncoder(emb_dim = emb_dim)
def forward(self, x, edge_index, edge_attr):
edge_embedding = self.bond_encoder(edge_attr) # 先将类别型边属性转换为边表征
out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
return out
def message(self, x_j, edge_attr):
return F.relu(x_j + edge_attr)
def update(self, aggr_out):
return aggr_out
在此篇文章中,我们学习了基于图同构网络(GIN)的图表征网络,为了得到图表征首先需要做节点表征,然后做图读出。GIN中节点表征的计算遵循WL Test算法中节点标签的更新方法,因此它的上界是WL Test算法。在图读出中,我们对所有的节点表征(加权,如果用Attention的话)求和,这会造成节点分布信息的丢失。
提出GlobalAttention的论文: “Gated Graph Sequence Neural Networks”
提出Set2Set的论文:“Order Matters: Sequence to sequence for sets”
PyG中集成的所有的图池化的方法:Global Pooling Layers
Weisfeiler-Lehman Test: Brendan L Douglas. The weisfeiler-lehman method and graph isomorphism testing. arXiv preprint arXiv:1101.5211, 2011.
Weisfeiler-Lehman Graph Kernels
知乎-风浪
Datawhale GNN 学习资料