这是一篇关于DGL消息传递的读后感。
这里, x v ∈ R d 1 x_v\in\mathbb{R}^{d_1} xv∈Rd1是节点 v v v的特征, w e ∈ R d 2 w_e\in\mathbb{R}^{d_2} we∈Rd2是边 ( u , v ) (u,v) (u,v)的特征。这个t+1时刻的消息传递框架定义如下的逐边(node-wise)和逐边的计算:
Edge-wise : m e t + 1 = ϕ ( x v t , x u t , w e t ) , ( u , v , e ) ∈ ε Node-wise : x v t + 1 = ψ ( x v t , ρ ( { m e t + 1 : ( u , v , e ) ∈ ε } ) ) \text{Edge-wise}:m_e^{t+1}=\phi(x_v^{t},x_u^{t},w_e^{t}),\ (u,v,e)\in\varepsilon \\ \text{Node-wise}:x_v^{t+1}=\psi(x_v^{t},\rho{\left(\left\{m_e^{t+1}:(u,v,e)\in \varepsilon\right\}\right)}) Edge-wise:met+1=ϕ(xvt,xut,wet), (u,v,e)∈εNode-wise:xvt+1=ψ(xvt,ρ({met+1:(u,v,e)∈ε}))
在上面的等式中, ϕ \phi ϕ是定义在每条边上的消息函数,它通过结合边特征和两个节点来产生message; ψ \psi ψ是定义在每个节点上的更新函数,通过使用reduce functaion
ρ \rho ρ去聚合(aggregating)输入信息,从而更新此节点。
Note:上面对Message Passing Paradigm的解释包含了很多信息。接下来我尝试通过下面这个图来解释这个计算框架:
在更新图节点的特征值,我们通常都是根据此节点的邻居节点更新的。假设我们想要更新节点6。那么,我们首先要知道的就是节点6和其邻居节点(2, 3, 7, 8, 9)通过相连的边传递的是什么信息,是相加吗?还是其它计算。此时这里的信息就是利用消息函数去完成的,消息函数通过 ( u , v , e ) (u,v,e) (u,v,e)在每条边上都生成了这个message,待会就利用这些信息去实施更新。
要更新节点6了,我们要知道节点6的信息更新只能通过边 a 36 , a 26 , a 76 , a 68 , a 69 a_{36},a_{26},a_{76},a_{68},a_{69} a36,a26,a76,a68,a69(对于无向图来说,节点的顺序不重要;有向图的就集中一个方向就行)上的信息去更新。这时候reduce function
ρ \rho ρ非常重要,它就是根据将节点6的邻居节点上的边信息收集起来,然后进行操作,可以是直接相加、相减、求和、平均、求权重和等等。那么最后,就利用 ψ \psi ψ去更新节点6的值。如果没听懂,我很抱歉,或许接下来的内容能让你重新了解。如果听懂了,我感到很欣慰,接下来的内容能让你更深刻的理解。
接下来分别介绍message function,reduce function和update function。
message function:由于消息函数会在边上产生信息,那么,它需要一个edges
参数。这个edges
有三个成员,分别是src, dst, data。能偶用来访问边上源节点的特征,目标节点的特征和边本身的特征。例子如下,假设我们将边上src的hu特征和dst上的hv特征相加,然后保存到he上。
def message_func(edges): # 消息函数的参数为edge
return {'he':edges.src['hu']+edges.dst['hv']}
当然,dgl库本身也有处理这方面的内置函数:dgl.function.u_add_v('hu', 'hv', 'he')
这里的u_add_v就表明了把源节点的特征和目标节点上的特征相加。
reduce function:需要有一个nodes
参数。它有一个mailbox
成员,能够用来访问这个节点收到的message。就像最开始讲的,一个节点只能够收到来自于邻居节点上的信息。所以它这个mailbox
就存储了这些信息。所以,如果我们想把mailbox
收到的message相加,然后存储到h里的话,也很简单:
import torch
def reduce_func(nodes):
return {'h':torch.sum(nodes.mailbox['m'], dim=1)}# 这里之所以有'm',就是消息存储到'm'这个键值上了,就像这里的'h'
当然,dgl也准备了内建函数dgl.function.sum('m', h)
。
update function:也需要一个参数nodes
参数。它通常在最后一步去结合reduce function
聚合的结果和目标节点的特征去更新目标节点的特征。
update_all()
:它是一个高阶函数,融合了消息的产生、聚合和节点的更新。所以,它需要三个参数:message function,reduce function和update function。也可以在update_all()
函数外调用更新函数而不用在update特别指定。DGL推荐这种在update_all()
外定义更新函数,因为为了代码的简洁,update function一般写成纯张量操作。具体实现例子如下,表达式为:将源节点j的特征ft_j
与相连的所有边a的特征a_ij
逐个相乘后再求和,最后将和乘以2。
final_ft i = 2 ∗ ∑ j ∈ N ( i ) ( f t j ∗ a i j ) \text{final\_ft}_{i}=2*\sum_{j\in{N(i)}}(ft_j*a_{ij}) final_fti=2∗j∈N(i)∑(ftj∗aij)
def update_all_example(graph):
"""
注意,这里的特征名称是一开始都设置好的。这个图本身包含了:
graph.ndata[ft]
graph.edata['a']
而m是临时用来存储message的
"""
graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
final_ft = graph.ndata['ft'] * 2
return final_ft
dgl.function
里实现了很多了message function和reduce function。
接下来,介绍一个重点,update_all
函数。使用这个函数,将极大简化代码。
此处不再赘述。DGL库在message passing教程中接下来的内容都是关于如何优化使用和在不同场景下的使用,核心并没有改变。接下来,我尝试讲解DGL例子中使用Message Passing来构造GCN的例子,来更加清晰的使用Message Passing。
在这篇文章中,节点的更新过程为:
Z = D ~ − 1 2 A ~ D ~ − 1 2 X Θ ( 2 ) Z=\widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}X\Theta \ \ \ \ (2) Z=D −21A D −21XΘ (2)
这里 A ~ = A + I \widetilde{A}=A+I A =A+I,也就是邻接矩阵增加了自环。主要,在实践中,已经有自环的节点无需再增加自环。 D ~ i i = ∑ j A ~ i j \widetilde{D}_{ii}=\sum_{j}\widetilde{A}_{ij} D ii=∑jA ij度矩阵是增加自环之后求得的。
g = dgl.remove_self_loop(g) # 增加自环时,首先去除原本的自环
g = dgl.add_self_loop(g)
degs = g.in_degrees().float() #无向图中,入度和出度是相同的。
norm = torch.pow(degs, -0.5)
接下来,我们需要将节点更新过程(2)用Message Passing来表达。首先,我们要知道Message Passing的思想就是在目标节点上求得edges上的信息,然后聚合起来更新目标节点。先给出最终表达式,有个目的性,然后再一步步推导:
x i k = ∑ j ∈ N ( i ) ∪ ( i ) 1 deg ( i ) deg ( j ) x j k − 1 Θ ( 3 ) x_i^{k}=\sum_{j\in N(i)\cup{(i)}}\frac{1}{\sqrt{\deg(i)}\sqrt{\text{deg}(j)}}x_j^{k-1}\Theta \ \ \ (3) xik=j∈N(i)∪(i)∑deg(i)deg(j)1xjk−1Θ (3)
以上的表达式说明,在更新第 k k k层的第 i i i个节点特征时,将 k − 1 k-1 k−1层第 i i i个节点特征与其邻居节点 j j j特征进行 Θ \Theta Θ转换、度的标准化,最后求和更新。这很符合图卷积的思想:将邻居节点的信息结合起来,更新目标节点。接下来将公式(2)进行分解到公式(3)。
我们一步步看公式(2):
D ~ − 1 2 A ~ = [ 1 deg ( 1 ) 0 ⋯ 0 0 1 deg ( 2 ) ⋯ 0 0 0 ⋯ 0 ⋮ ⋮ ⋱ 0 0 0 ⋯ 1 deg ( n ) ] ∗ A ~ \widetilde{D}^{-\frac{1}{2}}\widetilde{A} = \begin{bmatrix} \frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\ 0 & \frac{1}{\deg{(2)}} &\cdots &0\\ 0 & 0 & \cdots& 0\\ \vdots & \vdots &\ddots &0 \\ 0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}} \end{bmatrix}*\widetilde{A} D −21A =⎣⎢⎢⎢⎢⎢⎢⎢⎡deg(1)100⋮00deg(2)10⋮0⋯⋯⋯⋱⋯0000deg(n)1⎦⎥⎥⎥⎥⎥⎥⎥⎤∗A
Note: 1 deg ( i ) \frac{1}{\sqrt{\deg(i)}} deg(i)1是节点 i i i的度的标准化。
这里, D ~ − 1 2 A ~ \widetilde{D}^{-\frac{1}{2}}\widetilde{A} D −21A 就相当于将 A ~ \widetilde{A} A 的第 i i i行的值乘以节点 i i i的度。要进行下一步操作时,我们首先要搞清楚邻接矩阵(此文章增加了自环,但我们仍然用邻接矩阵称呼它) A ~ \widetilde{A} A 的意义。如下,
{ a i j = 1 当 节 点 j 是 节 点 i 的 邻 居 节 点 , 那 么 第 i 行 的 第 j 列 为 1 a i j = 0 其 它 \begin{cases} a_{ij}=1 & 当节点j是节点i的邻居节点,那么第i行的第j列为1 \\ a_{ij}=0 & 其它\\ \end{cases} {aij=1aij=0当节点j是节点i的邻居节点,那么第i行的第j列为1其它
我们令 D ~ − 1 2 A ~ D ~ − 1 2 = M \widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}=M D −21A D −21=M
D ~ − 1 2 A ~ D ~ − 1 2 = [ 1 deg ( 1 ) 0 ⋯ 0 0 1 deg ( 2 ) ⋯ 0 0 0 ⋯ 0 ⋮ ⋮ ⋱ 0 0 0 ⋯ 1 deg ( n ) ] ∗ A ~ ∗ [ 1 deg ( 1 ) 0 ⋯ 0 0 1 deg ( 2 ) ⋯ 0 0 0 ⋯ 0 ⋮ ⋮ ⋱ 0 0 0 ⋯ 1 deg ( n ) ] [ a 11 1 deg ( 1 ) ∗ 1 deg ( 1 ) a 12 1 deg ( 1 ) 1 deg ( 2 ) ⋯ a 1 n 1 deg ( 1 ) ∗ 1 deg ( n ) a 21 1 deg ( 2 ) ∗ 1 deg ( 1 ) a 22 1 deg ( 2 ) 1 deg ( 2 ) ⋯ a 2 n 1 deg ( 2 ) ∗ 1 deg ( n ) ⋮ ⋮ ⋱ ⋮ a n 1 1 deg ( n ) ∗ 1 deg ( 1 ) a n 2 1 deg ( n ) 1 deg ( 2 ) ⋯ a n n 1 deg ( n ) ∗ 1 deg ( n ) ] \widetilde{D}^{-\frac{1}{2}}\widetilde{A} \widetilde{D}^{-\frac{1}{2}}=\begin{bmatrix} \frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\ 0 & \frac{1}{\deg{(2)}} &\cdots &0\\ 0 & 0 & \cdots& 0\\ \vdots & \vdots &\ddots &0 \\ 0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}} \end{bmatrix}*\widetilde{A}*\begin{bmatrix} \frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\ 0 & \frac{1}{\deg{(2)}} &\cdots &0\\ 0 & 0 & \cdots& 0\\ \vdots & \vdots &\ddots &0 \\ 0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}} \end{bmatrix} \\ \begin{bmatrix}a_{11}\frac{1}{\sqrt{\deg{(1)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{12}\frac{1}{\sqrt{\deg{(1)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{1n}\frac{1}{\sqrt{\deg{(1)}}}*\frac{1}{\sqrt{\deg{(n)}}} \\a_{21}\frac{1}{\sqrt{\deg{(2)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{22}\frac{1}{\sqrt{\deg{(2)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{2n}\frac{1}{\sqrt{\deg{(2)}}}*\frac{1}{\sqrt{\deg{(n)}}}\\ \vdots & \vdots &\ddots &\vdots \\a_{n1}\frac{1}{\sqrt{\deg{(n)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{n2}\frac{1}{\sqrt{\deg{(n)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{nn}\frac{1}{\sqrt{\deg{(n)}}}*\frac{1}{\sqrt{\deg{(n)}}} \end{bmatrix} D −21A D −21=⎣⎢⎢⎢⎢⎢⎢⎢⎡deg(1)100⋮00deg(2)10⋮0⋯⋯⋯⋱⋯0000deg(n)1⎦⎥⎥⎥⎥⎥⎥⎥⎤∗A ∗⎣⎢⎢⎢⎢⎢⎢⎢⎡deg(1)100⋮00deg(2)10⋮0⋯⋯⋯⋱⋯0000deg(n)1⎦⎥⎥⎥⎥⎥⎥⎥⎤⎣⎢⎢⎢⎢⎢⎡a11deg(1)1∗deg(1)1a21deg(2)1∗deg(1)1⋮an1deg(n)1∗deg(1)1a12deg(1)1deg(2)1a22deg(2)1deg(2)1⋮an2deg(n)1deg(2)1⋯⋯⋱⋯a1ndeg(1)1∗deg(n)1a2ndeg(2)1∗deg(n)1⋮anndeg(n)1∗deg(n)1⎦⎥⎥⎥⎥⎥⎤
大家看到这里应该比较清楚了,这里的系数 a i j a_{ij} aij只能为0或者1,并且取决于 j j j是否为 i i i的邻居节点。 X Θ X\Theta XΘ仅仅只是将特征 X X X进行特征映射了,仅仅只是改变X的列维度。那么,接下来 M X MX MX的值就意味着矩阵M使用系数去选取X的值,这样就实现了选取邻居节点特征的意义。(害,其实我刚看到公式(2)的时候,内心是崩溃的,它怎么就实现抽取邻居节点特征的卷积效果,推导到公式(3)时才恍然大悟,原来是这么回事)。
那么MX的每一行值,就是公式三的结果了。我们再次看下公式(3):
x i k = ∑ j ∈ N ( i ) ∪ ( i ) 1 deg ( i ) deg ( j ) x j k − 1 Θ ( 3 ) x_i^{k}=\sum_{j\in N(i)\cup{(i)}}\frac{1}{\sqrt{\deg(i)}\sqrt{\text{deg}(j)}}x_j^{k-1}\Theta \ \ \ (3) xik=j∈N(i)∪(i)∑deg(i)deg(j)1xjk−1Θ (3)
分析出message function、reduce function为:
def gcn_msg(edge):
# 在边上的源节点上,乘以其度的正则
msg = edge.src['h'] * edge.src['norm']
return {'m': msg}
def gcn_reduce(node):
# 将目标节点的边上的信息聚合(这里是sum),再乘以目标节点上的度的正则
# 这里的torch.sum(, dim=1)在维度1上相加,是因为node.mailbox['m']的shape = [batch, mails, feat]
# 需要将mails整个加起来
accum = torch.sum(node.mailbox['m'], 1) * node.data['norm']
return {'h': accum} # 这时候,节点就存在一个数据node.data['h'] = accum
class NodeApplyModule(nn.Module):
def __init__(self, out_feats, activation=None, bias=True):
super(NodeApplyModule, self).__init__()
if bias:
self.bias = nn.Parameter(torch.Tensor(out_feats))
else:
self.bias = None
self.activation = activation
self.reset_parameters()
def reset_parameters(self):
if self.bias is not None:
stdv = 1. / math.sqrt(self.bias.size(0))
self.bias.data.uniform_(-stdv, stdv)
def forward(self, nodes): # 要更新时,添加上偏置和激活函数
h = nodes.data['h']
if self.bias is not None:
h = h + self.bias
if self.activation:
h = self.activation(h)
return {'h': h} # 此时nodes['h'] = h 这是就被更新了
class GCNLayer(nn.Module):
def __init__(self,
g,
in_feats,
out_feats,
activation,
dropout,
bias=True):
super(GCNLayer, self).__init__()
self.g = g
self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
if dropout:
self.dropout = nn.Dropout(p=dropout)
else:
self.dropout = 0.
self.node_update = NodeApplyModule(out_feats, activation, bias)
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
def forward(self, h):
if self.dropout:
h = self.dropout(h)
self.g.ndata['h'] = torch.mm(h, self.weight) # 这里是首先将节点特征进行映射,也就是X*O
self.g.update_all(gcn_msg, gcn_reduce, self.node_update) # 然后求message,聚合,更新。
h = self.g.ndata.pop('h')
return h
# add self loop
g = dgl.remove_self_loop(g)
g = dgl.add_self_loop(g)
n_edges = g.number_of_edges()
# normalization
degs = g.in_degrees().float()
norm = torch.pow(degs, -0.5)
norm[torch.isinf(norm)] = 0
g.ndata['norm'] = norm.unsqueeze(1) # 传递到GCN的图已经存在了'norm'数据