信息传递网络(Message Passing Neural Networks, MPNNs)是由Gilmer等人提出的一种图神经网络通用计算框架。原文以量子化学为例,根据原子的性质(对应节点特征)和分子的结构(对应边特征)预测了13种物理化学性质。查看论文原文请点击这里。
MPNN的前向传播包括两个阶段,第一个阶段称为message passing(信息传递)
阶段,第二个阶段称为readout(读取)
阶段。定义一张图 G = ( V , E ) G=(V,E) G=(V,E),其中 V V V是所有节点, E E E是所有边。
message passing
阶段会执行多次信息传递过程。对于一个特定的节点v
,我先给出公式。
m v t + 1 = ∑ w ∈ N ( v ) M t ( h v t , h w t , e v w ) (1) m_v^{t+1}=\sum_{w\in N(v)}M_t\left( h_v^{t},h_w^{t},e_{vw} \right)\tag{1} mvt+1=w∈N(v)∑Mt(hvt,hwt,evw)(1) h v t + 1 = U t ( h v t , m v t + 1 ) (2) h_v^{t+1}=U_t\left(h_v^{t},m_v^{t+1}\right)\tag{2} hvt+1=Ut(hvt,mvt+1)(2)
其中,在公式 ( 1 ) (1) (1)中, m v t + 1 m_v^{t+1} mvt+1是结点v
在t+1
时间步所接收到的信息, N ( v ) N(v) N(v)是结点v
的所有邻结点, h v t h_v^{t} hvt是结点v
在t
时间步的特征向量, e v w e_{vw} evw是结点v
和w
的边特征, M t M_t Mt是消息函数。该公式的意义是节点v
收到的信息来源于节点v
本身状态( h v t h_v^{t} hvt),周围的节点状态( h w t h_w^{t} hwt)和与之相连的边特征( e v w e_{vw} evw)。生成信息后,就需要对结点进行更新。
在公式 ( 2 ) (2) (2)中, U t U_t Ut是结点更新函数,该函数把原节点状态 h v t h_v^{t} hvt和信息 m v t + 1 m_v^{t+1} mvt+1作为输入,得到新的节点状态 h v t + 1 h_v^{t+1} hvt+1。熟悉RNN的同学可能会眼熟这个公式,这个更新函数和RNN里的更新函数是一样的。后面我们也可以看到,我们可以用GRU或LSTM来表示 U t U_t Ut。
最后再强调一下时间步的概念。计算完一次 ( 1 ) (1) (1)和 ( 2 ) (2) (2)算一个时间步,因此如果时间步设为 T T T,上述两个公式会各运行 T T T次,最终得到的结果是 h v T h_v^{T} hvT。
readout
阶段使用读取函数 R R R计算基于整张图的特征向量,可以表示为
y ^ = R ( { h v T ∣ v ∈ G } ) (3) \hat{y}=R\left(\{h_v^T|v \in G \} \right)\tag{3} y^=R({hvT∣v∈G})(3)
其中, y ^ \hat{y} y^是最终的输出向量, R R R是读取函数,这个函数有两个要求:1、要可以求导。2、要满足置换不变性(结点的输入顺序不改变最终结果,这也是为了保证MPNN对图的同构有不变性)
在MPNN的框架下,我们可以自定义消息函数、更新函数和读取函数,下面我举一个实际的案例,也是这篇文章所提及的门控图神经网络(Gated Graph Neural Networks, GG-NN)。这里,信息函数、结点更新函数和读取函数被定义为
M t ( h v t , h w t , e v w ) = A e v w h w t (4) M_t\left( h_v^{t},h_w^{t},e_{vw} \right)=A_{e_{vw}}h_w^t\tag{4} Mt(hvt,hwt,evw)=Aevwhwt(4) U t ( h v t , m v t + 1 ) = G R U ( h v t , m v t + 1 ) (5) U_t\left(h_v^{t},m_v^{t+1}\right)=GRU\left(h_v^{t},m_v^{t+1}\right)\tag{5} Ut(hvt,mvt+1)=GRU(hvt,mvt+1)(5) R = ∑ v ∈ V σ ( i ( h v ( T ) , h v 0 ) ) ⊙ ( j ( h v ( T ) ) ) (6) R=\sum_{v\in V}\sigma\left(i\left(h_v^{(T)},h_v^0\right)\right)\odot \left(j\left(h_v^{(T)}\right)\right)\tag{6} R=v∈V∑σ(i(hv(T),hv0))⊙(j(hv(T)))(6)
消息函数 ( 4 ) (4) (4)中,矩阵 A e v w A_{e_{vw}} Aevw决定了图中的结点是如何与其他结点进行相互作用的,一条边对应一个矩阵。但是这个函数描述得有些笼统。GGNN文章中的公式更清晰一些,如下所示
a v ( t ) = A v : T [ h 1 ( t − 1 ) T , h 2 ( t − 1 ) T , . . . , h ∣ V ∣ ( t − 1 ) T ] T + b (7) a_v^{\left(t\right)}=A_{v:}^T\left[h_1^{(t-1)^T},h_2^{(t-1)^T},...,h_{|V|}^{(t-1)^T}\right]^T+b\tag{7} av(t)=Av:T[h1(t−1)T,h2(t−1)T,...,h∣V∣(t−1)T]T+b(7)
其中, a v ( t ) a_v^{\left(t\right)} av(t)是结点v
在t
时刻接收到的信息向量,和我们之前定义的 m v t + 1 m_v^{t+1} mvt+1是一样的,只是换了些字母。 h ( t − 1 ) h^{(t-1)} h(t−1)表示节点在t-1
个时间步的状态,因此 [ h 1 ( t − 1 ) T , h 2 ( t − 1 ) T , . . . , h ∣ V ∣ ( t − 1 ) T ] T \left[h_1^{(t-1)^T},h_2^{(t-1)^T},...,h_{|V|}^{(t-1)^T}\right]^T [h1(t−1)T,h2(t−1)T,...,h∣V∣(t−1)T]T把每个结点的状态拼接在一个维度上,维度大小为 D ∣ V ∣ D|V| D∣V∣, b b b是偏置项,至于 A v : A_{v:} Av:,我们先看下面这张图
这里有一个边的特征矩阵 A A A,矩阵 A A A考虑了边的方向,因此它是由out
和in
两个部分拼接而成,图中的不同字母代表了不同的相互作用类型(也可以视为每条边的特征,注意每一条边的特征维度都是 ( D , D ) (D, D) (D,D),而不是我们常见的一维向量,在实际应用中,如果边的初始特征维度不是 D D D,可以进行embedding
或线性变换到 D × D D\times D D×D维,再reshape
到 ( D , D ) (D, D) (D,D)),最终的维度是 ( D ∣ V ∣ , 2 D ∣ V ∣ ) (D|V|,2D|V|) (D∣V∣,2D∣V∣),其中 ∣ V ∣ |V| ∣V∣是结点个数。有了矩阵 A A A之后,我们需要针对某一个结点选出“两列”
(并非真正意义上的两列)。以2号结点作为v
结点为例,我们在Outgoing Edges
和Incoming Edges
中分别找到2号结点,再把这两列拼接起来,得到一个维度是 ( D ∣ V ∣ , 2 D ) (D|V|,2D) (D∣V∣,2D)的矩阵 A v : A_{v:} Av:。将该矩阵的转置与所有节点的状态拼接成的列向量相乘,最终得到一个维度为 2 D 2D 2D的信息向量 a v ( t ) a_v^{\left(t\right)} av(t)。而对于无向图而言,只需要考虑一半的情况就行了。
结点更新函数 ( 5 ) (5) (5)是GRU
,对GRU
不熟悉的同学可以看一下这方面的知识,在此就不再多做解释了。
读取函数 ( 6 ) (6) (6)看起来是较为复杂的,我们可以拆开来看。首先 ⊙ \odot ⊙表示逐元素相乘, i i i和 j j j分别表示一个全连接神经网络,并且在 i i i的外面又套了一层sigmoid
函数,用符号 σ \sigma σ表示。对于神经网络 i i i而言,输入是结点的初始状态和最终状态,因此输入维度是2 * in_dim
,而对于神经网络 j j j而言,输入只有结点的最终状态,因此输入维度是in_dim
。但是这两个神经网络的输出维度是一样的,这样才能逐元素相乘。再往深入一点讲,这里包含了self attention
机制,就是在读取阶段要注意
该节点最初的特征。
我分别找到了Pytorch和Tensorflow的实现,以后有时间我会分析一下Pytorch版的实现过程。
Pytorch版
Tensorflow版(原作者)
torch-geometric