图神经网络学习6

基于图神经网络的图表征学习方法

图表征学习要求在输入节点属性、边和边的属性(如果有的话)得到一个向量作为图的表征,基于图表征进一步的我们可以做图的预测。本文以一个经典的同构网络GIN来举例。

网络实现

基于图同构网络的图表征学习主要包含以下两个过程

  1. 首先计算得到节点表征;
  2. 其次对图上各个节点的表征做图池化(Graph Pooling),或称为图读出(Graph Readout),得到图的表征(Graph Representation)。

首先是整个网络的框架,此模块首先采用GINNodeEmbedding模块对图上每一个节点做节点嵌入(Node Embedding),得到节点表征;然后对节点表征做图池化得到图的表征;最后用一层线性变换对图表征转换为对图的预测。代码实现如下:

import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool,
global_max_pool, GlobalAttention, Set2Set
from gin_node import GINNodeEmbedding

class GINGraphRepr(nn.Module):
    def __init__(self, num_tasks=1, num_layers=5, emb_dim=300,
                 residual=False, drop_ratio=0, JK='last', graph_pooling='sum'):
        """GIN Graph Pooling Module
        Args:
            num_tasks (int, optional): number of labels to be predicted. Defaults to 1 (控制了图表征的维度,dimension of graph representation).
            num_layers (int, optional): number of GINConv layers. Defaults to 5.
            emb_dim (int, optional): dimension of node embedding. Defaults to 300.
            residual (bool, optional): adding residual connection or not. Defaults to False.
            drop_ratio (float, optional): dropout rate. Defaults to 0.
            JK (str, optional): 可选的值为"last"和"sum"。选"last",只取最后一层的结点的嵌入,选"sum"对各层的结点的嵌入求和。Defaults to "last".
            graph_pooling (str, optional): pooling method of node embedding. 可选的值为"sum","mean","max","attention"和"set2set"。 Defaults to "sum".

        Out:
            graph representation
        """
        super(GINGraphPooling, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        
        self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=nn.Sequential(
                nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim),
                nn.Relu(), nn.Linear(emb_dim, 1)))
        elif graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")
        
        if graph_pooling == "set2set":
            self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)
            
    def forward(self, batch_data):
        h_node = self.gnn_node(batched_data)
        h_graph = self.pool(h_node, batched_data.batch)
        output = self.graph_pred_linear(h_graph)
        
        if self.training:
            return output
        else:
            return torch.clamp(output, min=0, max=50)

现在构建GINNodeEmbedding模块
我们首先用AtomEncoder对其做嵌入得到第0层节点表征(稍后我们再对AtomEncoder做分析)。然后我们逐层计算节点表征,从第1层开始到第num_layers层,每一层节点表征的计算都以上一层的节点表征h_list[layer]、边edge_index和边的属性edge_attr为输入。需要注意的是,GINConv的层数越多,此节点嵌入模块的感受野(receptive field)越大结点i的表征最远能捕获到结点i的距离为num_layers的邻接节点的信息

import torch
from mol_encoder import AtomEncoder
from gin_conv import GINConv
import torch.nn.functional as F

# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):
    """
    Output:
        node representations
    """

    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
        """GIN Node Embedding Module"""

        super(GINNodeEmbedding, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        # add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        # List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr

        # computing input node embedding
        h_list = [self.atom_encoder(x)]  # 先将类别型原子属性转化为原子表征
        for layer in range(self.num_layers):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        # Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers + 1):
                node_representation += h_list[layer]

        return node_representation

接下来我们来学习图同构网络的关键组件GINConv

GINConv–图同构卷积层

图同构卷积层的数学定义如下:
x i ′ = h Θ ( ( 1 + ϵ ) ⋅ x i + ∑ j ∈ N ( i ) x j ) \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right) xi=hΘ(1+ϵ)xi+jN(i)xj
PyG中已经实现了此模块,我们可以通过torch_geometric.nn.GINConv来使用PyG定义好的图同构卷积层,然而该实现不支持存在边属性的图。在这里我们自己自定义一个支持边属性的GINConv模块

import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder


### GIN convolution along the graph structure
class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__(aggr = "add")

        self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))
        self.eps = nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr) # 先将类别型边属性转换为边表征
        out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)
        
    def update(self, aggr_out):
        return aggr_out

理论分析

动机(Motivation)

新的图神经网络的设计大多基于经验性的直觉、启发式的方法和实验性的试错。人们对图神经网络的特性和局限性了解甚少,对图神经网络的表征能力学习的正式分析也很有限。

贡献与结论

  1. (理论上)图神经网络在区分图结构方面最高能达到与WL Test一样的能力。
  2. 确定了邻接节点聚合方法和图池化方法的应具备的条件,在这些条件下,所产生的图神经网络能达到与WL Test一样的能力。
  3. 分析出过去流行的图神经网络变体(如GCN和GraphSAGE)无法区分一些结构的图。
  4. 开发了一个简单的图神经网络模型–图同构网络(Graph Isomorphism Network, GIN),并证明其分辨图的同构性的能力与表示图的能力与WL Test相当。

背景:Weisfeiler-Lehman Test (WL Test)

图同构性测试

两个图是同构的,意思是两个图拥有一样的拓扑结构,也就是说,我们可以通过重新标记节点从一个图转换到另外一个图。Weisfeiler-Lehman 图的同构性测试算法,简称WL Test,是一种用于测试两个图是否同构的算法。

WL Test 的一维形式,类似于图神经网络中的邻接节点聚合。WL Test 1)迭代地聚合节点及其邻接节点的标签,然后 2)将聚合的标签散列(hash)成新标签,该过程形式化为下方的公示,
L u h ← hash ⁡ ( L u h − 1 + ∑ v ∈ N ( U ) L v h − 1 ) L^{h}_{u} \leftarrow \operatorname{hash}\left(L^{h-1}_{u} + \sum_{v \in \mathcal{N}(U)} L^{h-1}_{v}\right) LuhhashLuh1+vN(U)Lvh1
在上方的公示中, L u h L^{h}_{u} Luh表示节点 u u u的第 h h h次迭代的标签,第 0 0 0次迭代的标签为节点原始标签。

在迭代过程中,发现两个图之间的节点的标签不同时,就可以确定这两个图是非同构的。需要注意的是节点标签可能的取值只能是有限个数。

WL测试不能保证对所有图都有效,特别是对于具有高度对称性的图,如链式图、完全图、环图和星图,它会判断错误。

Weisfeiler-Lehman Graph Kernels 方法提出用WL子树核衡量图之间相似性。该方法使用WL Test不同迭代中的节点标签计数作为图的表征向量,它具有与WL Test相同的判别能力。直观地说,在WL Test的第 k k k次迭代中,一个节点的标签代表了以该节点为根的高度为 k k k的子树结构。

作业

  • 请画出下方图片中的6号、3号和5号节点的从1层到3层到WL子树。
    图神经网络学习6_第1张图片
    答案:
    图神经网络学习6_第2张图片
    本文内容来自datawhale开源课程
    参考链接:https://github.com/datawhalechina/team-learning-nlp/tree/master/GNN

你可能感兴趣的:(神经网络)