[Scene Graph] 图神经网络的核心方法——Message Passing

GNN中的Message Passing方法解析
一、GNN中是如何实现特征学习的?

深度学习方法的兴起是从计算图像处理(Computer Vision)领域开始的。以卷积神经网络(CNN)为代表的方法会从邻近的像素中获取信息。这种方式对于结构化数据(structured data)十分有效,例如,图像和体素数据。但是,CNN的处理方式对于类似图(graph)数据则并不适用。对于一个图而言,类似图像像素的邻近关系是不存在的。考虑图中任意一个节点,我们几乎不可能始终找到8个有序邻近节点,因此类CNN卷积核将不再适用非结构数据(unstructured data)特征提取,这种非结构数据包括图(graph)和点集(set),例如点云数据(point cloud)和场景图(scene graph)。

那么对于图来说,GNN是如何学习到有效的特征呢?这就是这篇博文将要讨论的十分重要的方法——Message Passing。如果知道Message Passing准确的中文翻译的小伙伴可以给我留言,这篇博文中我将用缩写MP代表Message Passing。

大致来说,GNN与CNN在思想上是相通的。GNN也希望能够从“邻接”的元素中获取信息,并学习到固定的且可以固化在向量中的特征。因此,在字面意义上来说MP的思想就是,节点与节点之间可以相互传播(pass)信息(message),通过对于邻接节点信息的处理和原本节点特征的融合,来更新每一个点的特征,最后得到节点级别(node-level)或者图级别(graph-level)的序列化模式表达,这种表达可以完成节点分类,图分类,图生成等十分基本的任务。

二、如何用数学定义Message Passing过程?

小伙伴们不用害怕复杂的数学公式,只需要跟着思路走就可以,我会使用可以直观理解的变量命名方式的。首先,先在脑海中构建一个抽象的图结构 G \mathcal{G} G,这个图的节点集合为 V \mathcal{V} V,边的集合为 E \mathcal{E} E。重点讨论的变量我列在了下面:

u , u ∈ V u, u\in \mathcal{V} u,uV:我们重点讨论节点,可以是图中的任意一个节点。
{ v ∣ v ∈ N ( u ) } \{v| v \in N(u)\} {vvN(u)}:节点 u u u的邻接节点集合,与 u u u有有效连接的节点的集合。
x u x_{u} xu:节点 u u u的初始特征向量,如果是节点 v v v的特征则标记为 x v x_v xv
h u h_{u} hu:节点 u u u的隐层特征向量,如果是节点 v v v的隐层特征则标记为 h v h_v hv。若GNN中包含许多中间层,那么使用 h u l h^{l}_{u} hul来表述第 l l l层的节点 u u u的隐层特征。

MP的过程可以由下述一个公式概括,这里借用文献[1]中的表述:
h u l = F c o m b i n e ( h u l − 1 , F a g g r e g a t e ( h v l − 1 ) ) h^{l}_{u}=F_{combine}(h^{l-1}_u, F_{aggregate}(h^{l-1}_{v})) hul=Fcombine(hul1,Faggregate(hvl1))

简单来说 F a g g r e g a t e ( ⋅ ) F_{aggregate}(·) Faggregate()的作用就是从 u u u的邻接节点中汇集(aggregate)特征并加以变换,得到节点 u u u在本次得到的信息,因此,也可以写作:
m u l = F a g g r e g a t e ( h v l − 1 ) m^l_u=F_{aggregate}(h^{l-1}_{v}) mul=Faggregate(hvl1)

在节点 u u u获得到由周围节点提供的信息之后, F c o m b i n e ( ⋅ ) F_{combine}(·) Fcombine()将信息 m u l m^l_u mul和本节点上一层的信息整合(combine)到一起,更新在这次操作之后节点 u u u新的特征。

三、Message Passing之后的输出过程(不感兴趣可以略过,并不是本篇详细讨论的问题)

经过多层的MP过程之后,对于任意一个节点,都有一系列的隐层特征 { h u 0 , . . . , h u l } \{h^0_u,...,h^l_u\} {hu0,...,hul}。一般情况下可以通过sum,mean,max或者weighted sum这四种方式进行pooling操作,得到节点级别的输出。文章[2] 中的消融实验(ablation study)证明了learning weighted sum比sum和mean的池化方法更加好。不过,仅在[2]的数据集上测试过,我并不把这个现象看做普遍规律,还是需要平衡参数量和精度需求选择池化方式。对于图级别的输出,由于个人研究方向的限制,并没有很大的兴趣深究,以后再说啦。

四、常用的Message Passing方法

想要快速看懂大多数的GNN文章,掌握在第二节里介绍MP数学形式就基本足够了。之后就可以整理现不同GNN方法的优劣。更重要的是,也可以按照以下步骤来构建自己的MP算法。我自己使用的步骤如下:

  1. 注意图节点的连接方式,包括单向图、无向图(双向图)、匀质图(图中节点的性质是一致的)、异质图、稀疏图、稠密图等。构建图的邻接矩阵。
  2. F a g g r e g a t e F_{aggregate} Faggregate:获取信息的方法主要以邻近节点的特征为输入,通过sum,concatenate,attention等机制融合邻接节点的特征。通过线性(linear)或者非线性(一般使用mlp)变换得到message特征。
  3. F c o m b i n e F_{combine} Fcombine:将当前节点的特征与message融合,完成当前轮的MP迭代。简单的方式可以使用MLP来实现融合。另外,可以使用LSTM和GRU门控的方式保留部分之前节点的信息来更新当前节点的特征。
  4. 节点输出或者图输出。
五、Pytorch Geometric示例代码,理论和实践结合

Pytorch Geometric库中已经写好了MP算法的蓝图,而且成熟的GNN的卷积层都已经写在torch_geometric.nn.conv头文件里,可以随时查阅或扩展。

下面展示了Pytorch Geometric中的GCN[3]论文中MP算法的实现过程。先说明一下能看懂这段代码所需要的的基本知识:

  1. GCNConv这个类别是继承了MessagePassing父类的,其中的4个函数 __ init__, forward, message和update都是MessagePassing类中定义好且需要我们自己扩展的函数。先放一下完整的代码。
  2. __ init__函数用来声明并初始化变量的,可以通过aggr变量控制特征的融合方式,通过flow变量来控制信息的流动方向,包括“source_to_target”和“target_to_source”。
  3. forward函数就是函数的运行的主体,使用pytorch的小伙伴一定不会陌生。这里提一下的是,这个函数最后需要调用self.propagate(edge_index, x, args)函数来进行MP算法。其中edge_index是邻接矩阵的稀疏表达,x是节点的特征矩阵,args就是你想在MP算法中调用的其他变量。
  4. message函数,顾名思义,就是用来计算message的函数。它的输入是self.propagate函数传入的变量们,基本上这个函数就是可以自由发挥的地方了。
  5. update函数是用来更新节点特征的函数,也就是combine的过程了。由于我平时喜欢额外调用GRUcell来计算,因此这个函数没有深入研究过。

完整代码如下:

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
 
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation.
        self.lin = torch.nn.Linear(in_channels, out_channels)
 
    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
 
        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
 
        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)
 
        # Step 3-5: Start propagating messages.
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)
 
    def message(self, x_j, edge_index, size):
        # x_j has shape [E, out_channels]
 
        # Step 3: Normalize node features.
        row, col = edge_index
        deg = degree(row, size[0], dtype=x_j.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
 
        return norm.view(-1, 1) * x_j
 
    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]
 
        # Step 5: Return new node embeddings.
        return aggr_out
参考文献

以下参考文献非正规格式,仅提供能查阅到文献的必要信息,刚写论文的小伙伴不要学习哦。
[1] V. Thost, J. Chen. Directed Acyclic Graph Neural Networks. ICLR 2021.
[2] D. Xu, Y. Zhu, et al. Scene Graph Generation by Iterative Message Passing. CVPR 2017.
[3] Kipf, Welling. Semi-Supervised Classification With Graph Convolutional Networks. ICLR 2017.

你可能感兴趣的:(Scene,Graph,神经网络,深度学习,机器学习,人工智能)