图表示学习在图任务占据了十分重要的位置。我们一般在先图节点的表征,然后再通过图池化的方法。图表征学习要求在输 入节点属性、边和边的属性得到一个向量作为图的节点表征。在图节点的表征的基础上,可以进一步做图的预测,比如图的同构性等。图表征网络是当前最经典的图表征学习网络是GIN(图同构网络)。
GIN具体请参考How Powerful are Graph Neural Networks?。
import torch
from mol_encoder import AtomEncoder
from gin_conv import GINConv
import torch.nn.functional as F
class GINNodeEmbedding(torch.nn.Module):
def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
super(GINNodeEmbedding, self).__init__()
self.num_layers = num_layers
self.drop_ratio = drop_ratio
self.JK = JK
self.residual = residual
if self.num_layers < 2:
raise ValueError("Number of GNN layers must be greater than 1.")
self.atom_encoder = AtomEncoder(emb_dim)
self.convs = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
for layer in range(num_layers):
self.convs.append(GINConv(emb_dim))
self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
def forward(self, batched_data):
x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr
h_list = [self.atom_encoder(x)]
for layer in range(self.num_layers):
h = self.convs[layer](h_list[layer], edge_index, edge_attr)
h = self.batch_norms[layer](h)
if layer == self.num_layers - 1:
h = F.dropout(h, self.drop_ratio, training=self.training)
else:
h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
if self.residual:
h += h_list[layer]
h_list.append(h)
if self.JK == "last":
node_representation = h_list[-1]
elif self.JK == "sum":
node_representation = 0
for layer in range(self.num_layers + 1):
node_representation += h_list[layer]
return node_representation
我们把输入到此节点嵌入模块的节点属性为类别型向量,用AtomEncoder 对其做嵌入得到第一层节点表征,然后再逐层计算节点表征。值得注意的是,GIN的层数越多,此节点嵌入模块的感受野也就越大,中心结点的表征最远能捕获距离更远邻接节点的信息。
我们也可以通过继承GINConv来进行构建函数。其图同构卷积层的数学定义如下:
x i ′ = h Θ ( ( 1 + ϵ ) ⋅ x i + ∑ j ∈ N ( i ) x j ) \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right) xi′=hΘ((1+ϵ)⋅xi+∑j∈N(i)xj)
class GINConv(MessagePassing):
def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
**kwargs):
kwargs.setdefault('aggr', 'add')
super(GINConv, self).__init__(**kwargs)
self.nn = nn
self.initial_eps = eps
if train_eps:
self.eps = torch.nn.Parameter(torch.Tensor([eps]))
else:
self.register_buffer('eps', torch.Tensor([eps]))
self.reset_parameters()
def reset_parameters(self):
reset(self.nn)
self.eps.data.fill_(self.initial_eps)
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
size: Size = None) -> Tensor:
if isinstance(x, Tensor):
x: OptPairTensor = (x, x)
# propagate_type: (x: OptPairTensor)
out = self.propagate(edge_index, x=x, size=size)
x_r = x[1]
if x_r is not None:
out += (1 + self.eps) * x_r
return self.nn(out)
def message(self, x_j: Tensor) -> Tensor:
return x_j
def message_and_aggregate(self, adj_t: SparseTensor,
x: OptPairTensor) -> Tensor:
adj_t = adj_t.set_value(None, layout=None)
return matmul(adj_t, x[0], reduce=self.aggr)
def __repr__(self):
return '{}(nn={})'.format(self.__class__.__name__, self.nn)
WL Test是一种用于测试两个图是否同构的算法,是图的同构性测试算法。但是WL Test也有很大的一个缺点,不能在具有高度对称的图上其作用,容易产生误判。
WL Graph Kernels方法提出用WL子树的方法去测试图之间的相似性。WL子树图如下所示:
请画出下方图片中的6号、3号和5号节点的从1层到3层到WL子树。
[1] https://github.com/datawhalechina/team-learning-nlp
[2] https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html
[3] How Powerful are Graph Neural Networks?