图神经网络通用框架信息传递网络(MPNNs)

图神经网络通用框架信息传递网络(MPNNs)

  • 介绍
  • 机制
    • 理论
      • 信息传递阶段
      • 读取阶段
    • 实际案例
    • 代码
  • 第三方库

介绍

信息传递网络(Message Passing Neural Networks, MPNNs)是由Gilmer等人提出的一种图神经网络通用计算框架。原文以量子化学为例,根据原子的性质(对应节点特征)和分子的结构(对应边特征)预测了13种物理化学性质。查看论文原文请点击这里。

机制

理论

MPNN的前向传播包括两个阶段,第一个阶段称为message passing(信息传递)阶段,第二个阶段称为readout(读取)阶段。定义一张图 G = ( V , E ) G=(V,E) G=(V,E),其中 V V V是所有节点, E E E是所有边。

信息传递阶段

message passing阶段会执行多次信息传递过程。对于一个特定的节点v,我先给出公式。
m v t + 1 = ∑ w ∈ N ( v ) M t ( h v t , h w t , e v w ) (1) m_v^{t+1}=\sum_{w\in N(v)}M_t\left( h_v^{t},h_w^{t},e_{vw} \right)\tag{1} mvt+1=wN(v)Mt(hvt,hwt,evw)(1) h v t + 1 = U t ( h v t , m v t + 1 ) (2) h_v^{t+1}=U_t\left(h_v^{t},m_v^{t+1}\right)\tag{2} hvt+1=Ut(hvt,mvt+1)(2)
其中,在公式 ( 1 ) (1) (1)中, m v t + 1 m_v^{t+1} mvt+1是结点vt+1时间步所接收到的信息, N ( v ) N(v) N(v)是结点v的所有邻结点, h v t h_v^{t} hvt是结点vt时间步的特征向量, e v w e_{vw} evw是结点vw的边特征, M t M_t Mt是消息函数。该公式的意义是节点v收到的信息来源于节点v本身状态( h v t h_v^{t} hvt),周围的节点状态( h w t h_w^{t} hwt)和与之相连的边特征( e v w e_{vw} evw)。生成信息后,就需要对结点进行更新。

在公式 ( 2 ) (2) (2)中, U t U_t Ut是结点更新函数,该函数把原节点状态 h v t h_v^{t} hvt和信息 m v t + 1 m_v^{t+1} mvt+1作为输入,得到新的节点状态 h v t + 1 h_v^{t+1} hvt+1。熟悉RNN的同学可能会眼熟这个公式,这个更新函数和RNN里的更新函数是一样的。后面我们也可以看到,我们可以用GRU或LSTM来表示 U t U_t Ut

最后再强调一下时间步的概念。计算完一次 ( 1 ) (1) (1) ( 2 ) (2) (2)算一个时间步,因此如果时间步设为 T T T,上述两个公式会各运行 T T T次,最终得到的结果是 h v T h_v^{T} hvT

读取阶段

readout阶段使用读取函数 R R R计算基于整张图的特征向量,可以表示为
y ^ = R ( { h v T ∣ v ∈ G } ) (3) \hat{y}=R\left(\{h_v^T|v \in G \} \right)\tag{3} y^=R({hvTvG})(3)
其中, y ^ \hat{y} y^是最终的输出向量, R R R是读取函数,这个函数有两个要求:1、要可以求导。2、要满足置换不变性(结点的输入顺序不改变最终结果,这也是为了保证MPNN对图的同构有不变性)

实际案例

在MPNN的框架下,我们可以自定义消息函数、更新函数和读取函数,下面我举一个实际的案例,也是这篇文章所提及的门控图神经网络(Gated Graph Neural Networks, GG-NN)。这里,信息函数、结点更新函数和读取函数被定义为
M t ( h v t , h w t , e v w ) = A e v w h w t (4) M_t\left( h_v^{t},h_w^{t},e_{vw} \right)=A_{e_{vw}}h_w^t\tag{4} Mt(hvt,hwt,evw)=Aevwhwt(4) U t ( h v t , m v t + 1 ) = G R U ( h v t , m v t + 1 ) (5) U_t\left(h_v^{t},m_v^{t+1}\right)=GRU\left(h_v^{t},m_v^{t+1}\right)\tag{5} Ut(hvt,mvt+1)=GRU(hvt,mvt+1)(5) R = ∑ v ∈ V σ ( i ( h v ( T ) , h v 0 ) ) ⊙ ( j ( h v ( T ) ) ) (6) R=\sum_{v\in V}\sigma\left(i\left(h_v^{(T)},h_v^0\right)\right)\odot \left(j\left(h_v^{(T)}\right)\right)\tag{6} R=vVσ(i(hv(T),hv0))(j(hv(T)))(6)

消息函数 ( 4 ) (4) (4)中,矩阵 A e v w A_{e_{vw}} Aevw决定了图中的结点是如何与其他结点进行相互作用的,一条边对应一个矩阵。但是这个函数描述得有些笼统。GGNN文章中的公式更清晰一些,如下所示
a v ( t ) = A v : T [ h 1 ( t − 1 ) T , h 2 ( t − 1 ) T , . . . , h ∣ V ∣ ( t − 1 ) T ] T + b (7) a_v^{\left(t\right)}=A_{v:}^T\left[h_1^{(t-1)^T},h_2^{(t-1)^T},...,h_{|V|}^{(t-1)^T}\right]^T+b\tag{7} av(t)=Av:T[h1(t1)T,h2(t1)T,...,hV(t1)T]T+b(7)
其中, a v ( t ) a_v^{\left(t\right)} av(t)是结点vt时刻接收到的信息向量,和我们之前定义的 m v t + 1 m_v^{t+1} mvt+1是一样的,只是换了些字母。 h ( t − 1 ) h^{(t-1)} h(t1)表示节点在t-1个时间步的状态,因此 [ h 1 ( t − 1 ) T , h 2 ( t − 1 ) T , . . . , h ∣ V ∣ ( t − 1 ) T ] T \left[h_1^{(t-1)^T},h_2^{(t-1)^T},...,h_{|V|}^{(t-1)^T}\right]^T [h1(t1)T,h2(t1)T,...,hV(t1)T]T把每个结点的状态拼接在一个维度上,维度大小为 D ∣ V ∣ D|V| DV b b b是偏置项,至于 A v : A_{v:} Av:,我们先看下面这张图
图神经网络通用框架信息传递网络(MPNNs)_第1张图片
这里有一个边的特征矩阵 A A A,矩阵 A A A考虑了边的方向,因此它是由outin两个部分拼接而成,图中的不同字母代表了不同的相互作用类型(也可以视为每条边的特征,注意每一条边的特征维度都是 ( D , D ) (D, D) (D,D),而不是我们常见的一维向量,在实际应用中,如果边的初始特征维度不是 D D D,可以进行embedding或线性变换到 D × D D\times D D×D维,再reshape ( D , D ) (D, D) (D,D)),最终的维度是 ( D ∣ V ∣ , 2 D ∣ V ∣ ) (D|V|,2D|V|) (DV,2DV),其中 ∣ V ∣ |V| V是结点个数。有了矩阵 A A A之后,我们需要针对某一个结点选出“两列”(并非真正意义上的两列)。以2号结点作为v结点为例,我们在Outgoing EdgesIncoming Edges中分别找到2号结点,再把这两列拼接起来,得到一个维度是 ( D ∣ V ∣ , 2 D ) (D|V|,2D) (DV,2D)的矩阵 A v : A_{v:} Av:。将该矩阵的转置与所有节点的状态拼接成的列向量相乘,最终得到一个维度为 2 D 2D 2D的信息向量 a v ( t ) a_v^{\left(t\right)} av(t)。而对于无向图而言,只需要考虑一半的情况就行了。

结点更新函数 ( 5 ) (5) (5)GRU,对GRU不熟悉的同学可以看一下这方面的知识,在此就不再多做解释了。

读取函数 ( 6 ) (6) (6)看起来是较为复杂的,我们可以拆开来看。首先 ⊙ \odot 表示逐元素相乘, i i i j j j分别表示一个全连接神经网络,并且在 i i i的外面又套了一层sigmoid函数,用符号 σ \sigma σ表示。对于神经网络 i i i而言,输入是结点的初始状态和最终状态,因此输入维度是2 * in_dim,而对于神经网络 j j j而言,输入只有结点的最终状态,因此输入维度是in_dim。但是这两个神经网络的输出维度是一样的,这样才能逐元素相乘。再往深入一点讲,这里包含了self attention机制,就是在读取阶段要注意该节点最初的特征。

代码

我分别找到了Pytorch和Tensorflow的实现,以后有时间我会分析一下Pytorch版的实现过程。
Pytorch版
Tensorflow版(原作者)

第三方库

torch-geometric

你可能感兴趣的:(图神经网络)