DGL- Message Passing

Message Passing

这是一篇关于DGL消息传递的读后感。

Message Passing Paradigm

​ 这里, x v ∈ R d 1 x_v\in\mathbb{R}^{d_1} xvRd1是节点 v v v的特征, w e ∈ R d 2 w_e\in\mathbb{R}^{d_2} weRd2是边 ( u , v ) (u,v) (u,v)的特征。这个t+1时刻的消息传递框架定义如下的逐边(node-wise)和逐边的计算:
Edge-wise : m e t + 1 = ϕ ( x v t , x u t , w e t ) ,   ( u , v , e ) ∈ ε Node-wise : x v t + 1 = ψ ( x v t , ρ ( { m e t + 1 : ( u , v , e ) ∈ ε } ) ) \text{Edge-wise}:m_e^{t+1}=\phi(x_v^{t},x_u^{t},w_e^{t}),\ (u,v,e)\in\varepsilon \\ \text{Node-wise}:x_v^{t+1}=\psi(x_v^{t},\rho{\left(\left\{m_e^{t+1}:(u,v,e)\in \varepsilon\right\}\right)}) Edge-wise:met+1=ϕ(xvt,xut,wet), (u,v,e)εNode-wise:xvt+1=ψ(xvt,ρ({met+1:(u,v,e)ε}))
​ 在上面的等式中, ϕ \phi ϕ是定义在每条边上的消息函数,它通过结合边特征和两个节点来产生message; ψ \psi ψ是定义在每个节点上的更新函数,通过使用reduce functaion ρ \rho ρ聚合(aggregating)输入信息,从而更新此节点。

​ Note:上面对Message Passing Paradigm的解释包含了很多信息。接下来我尝试通过下面这个图来解释这个计算框架:
DGL- Message Passing_第1张图片

​ 在更新图节点的特征值,我们通常都是根据此节点的邻居节点更新的。假设我们想要更新节点6。那么,我们首先要知道的就是节点6和其邻居节点(2, 3, 7, 8, 9)通过相连的边传递的是什么信息,是相加吗?还是其它计算。此时这里的信息就是利用消息函数去完成的,消息函数通过 ( u , v , e ) (u,v,e) (u,v,e)在每条边上都生成了这个message,待会就利用这些信息去实施更新。

​ 要更新节点6了,我们要知道节点6的信息更新只能通过边 a 36 , a 26 , a 76 , a 68 , a 69 a_{36},a_{26},a_{76},a_{68},a_{69} a36,a26,a76,a68,a69(对于无向图来说,节点的顺序不重要;有向图的就集中一个方向就行)上的信息去更新。这时候reduce function ρ \rho ρ非常重要,它就是根据将节点6的邻居节点上的边信息收集起来,然后进行操作,可以是直接相加、相减、求和、平均、求权重和等等。那么最后,就利用 ψ \psi ψ去更新节点6的值。如果没听懂,我很抱歉,或许接下来的内容能让你重新了解。如果听懂了,我感到很欣慰,接下来的内容能让你更深刻的理解。

2.1 Built-in Functions and Message Passing APIs

​ 接下来分别介绍message functionreduce functionupdate function

  • message function:由于消息函数会在边上产生信息,那么,它需要一个edges参数。这个edges有三个成员,分别是src, dst, data。能偶用来访问边上源节点的特征,目标节点的特征和边本身的特征。例子如下,假设我们将边上src的hu特征和dst上的hv特征相加,然后保存到he上。

    def message_func(edges):  # 消息函数的参数为edge
        return {'he':edges.src['hu']+edges.dst['hv']}
    

    当然,dgl库本身也有处理这方面的内置函数:dgl.function.u_add_v('hu', 'hv', 'he')这里的u_add_v就表明了把源节点的特征和目标节点上的特征相加。

  • reduce function:需要有一个nodes参数。它有一个mailbox成员,能够用来访问这个节点收到的message。就像最开始讲的,一个节点只能够收到来自于邻居节点上的信息。所以它这个mailbox就存储了这些信息。所以,如果我们想把mailbox收到的message相加,然后存储到h里的话,也很简单:

    import torch
    def reduce_func(nodes):
        return {'h':torch.sum(nodes.mailbox['m'], dim=1)}# 这里之所以有'm',就是消息存储到'm'这个键值上了,就像这里的'h'
    

    当然,dgl也准备了内建函数dgl.function.sum('m', h)

  • update function:也需要一个参数nodes参数。它通常在最后一步去结合reduce function聚合的结果和目标节点的特征去更新目标节点的特征。

  • update_all():它是一个高阶函数,融合了消息的产生、聚合和节点的更新。所以,它需要三个参数:message function,reduce function和update function。也可以在update_all()函数外调用更新函数而不用在update特别指定。DGL推荐这种在update_all()外定义更新函数,因为为了代码的简洁,update function一般写成纯张量操作。具体实现例子如下,表达式为:将源节点j的特征ft_j与相连的所有边a的特征a_ij逐个相乘后再求和,最后将和乘以2。

    final_ft i = 2 ∗ ∑ j ∈ N ( i ) ( f t j ∗ a i j ) \text{final\_ft}_{i}=2*\sum_{j\in{N(i)}}(ft_j*a_{ij}) final_fti=2jN(i)(ftjaij)

    def update_all_example(graph):
        """
        	注意,这里的特征名称是一开始都设置好的。这个图本身包含了:
        	graph.ndata[ft]
        	graph.edata['a']
        	而m是临时用来存储message的
        """
        graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
        final_ft = graph.ndata['ft'] * 2
        return final_ft
    

    dgl.function里实现了很多了message functionreduce function

    接下来,介绍一个重点,update_all函数。使用这个函数,将极大简化代码。

    此处不再赘述。DGL库在message passing教程中接下来的内容都是关于如何优化使用和在不同场景下的使用,核心并没有改变。接下来,我尝试讲解DGL例子中使用Message Passing来构造GCN的例子,来更加清晰的使用Message Passing。

使用Message Passing构造GCN

​ 在这篇文章中,节点的更新过程为:
Z = D ~ − 1 2 A ~ D ~ − 1 2 X Θ      ( 2 ) Z=\widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}X\Theta \ \ \ \ (2) Z=D 21A D 21XΘ    (2)
​ 这里 A ~ = A + I \widetilde{A}=A+I A =A+I,也就是邻接矩阵增加了自环。主要,在实践中,已经有自环的节点无需再增加自环。 D ~ i i = ∑ j A ~ i j \widetilde{D}_{ii}=\sum_{j}\widetilde{A}_{ij} D ii=jA ij度矩阵是增加自环之后求得的。

g = dgl.remove_self_loop(g) # 增加自环时,首先去除原本的自环
g = dgl.add_self_loop(g)
degs = g.in_degrees().float() #无向图中,入度和出度是相同的。
norm = torch.pow(degs, -0.5)

​ 接下来,我们需要将节点更新过程(2)用Message Passing来表达。首先,我们要知道Message Passing的思想就是在目标节点上求得edges上的信息,然后聚合起来更新目标节点。先给出最终表达式,有个目的性,然后再一步步推导:
x i k = ∑ j ∈ N ( i ) ∪ ( i ) 1 deg ⁡ ( i ) deg ( j ) x j k − 1 Θ     ( 3 ) x_i^{k}=\sum_{j\in N(i)\cup{(i)}}\frac{1}{\sqrt{\deg(i)}\sqrt{\text{deg}(j)}}x_j^{k-1}\Theta \ \ \ (3) xik=jN(i)(i)deg(i) deg(j) 1xjk1Θ   (3)
以上的表达式说明,在更新第 k k k层的第 i i i个节点特征时,将 k − 1 k-1 k1层第 i i i个节点特征与其邻居节点 j j j特征进行 Θ \Theta Θ转换、度的标准化,最后求和更新。这很符合图卷积的思想:将邻居节点的信息结合起来,更新目标节点。接下来将公式(2)进行分解到公式(3)。

​ 我们一步步看公式(2):
D ~ − 1 2 A ~ = [ 1 deg ⁡ ( 1 ) 0 ⋯ 0 0 1 deg ⁡ ( 2 ) ⋯ 0 0 0 ⋯ 0 ⋮ ⋮ ⋱ 0 0 0 ⋯ 1 deg ⁡ ( n ) ] ∗ A ~ \widetilde{D}^{-\frac{1}{2}}\widetilde{A} = \begin{bmatrix} \frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\ 0 & \frac{1}{\deg{(2)}} &\cdots &0\\ 0 & 0 & \cdots& 0\\ \vdots & \vdots &\ddots &0 \\ 0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}} \end{bmatrix}*\widetilde{A} D 21A =deg(1) 10000deg(2)1000000deg(n) 1A
​ Note: 1 deg ⁡ ( i ) \frac{1}{\sqrt{\deg(i)}} deg(i) 1是节点 i i i的度的标准化。

​ 这里, D ~ − 1 2 A ~ \widetilde{D}^{-\frac{1}{2}}\widetilde{A} D 21A 就相当于将 A ~ \widetilde{A} A 的第 i i i行的值乘以节点 i i i的度。要进行下一步操作时,我们首先要搞清楚邻接矩阵(此文章增加了自环,但我们仍然用邻接矩阵称呼它) A ~ \widetilde{A} A 的意义。如下,
{ a i j = 1 当 节 点 j 是 节 点 i 的 邻 居 节 点 , 那 么 第 i 行 的 第 j 列 为 1 a i j = 0 其 它 \begin{cases} a_{ij}=1 & 当节点j是节点i的邻居节点,那么第i行的第j列为1 \\ a_{ij}=0 & 其它\\ \end{cases} {aij=1aij=0jiij1

​ 我们令 D ~ − 1 2 A ~ D ~ − 1 2 = M \widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}=M D 21A D 21=M
D ~ − 1 2 A ~ D ~ − 1 2 = [ 1 deg ⁡ ( 1 ) 0 ⋯ 0 0 1 deg ⁡ ( 2 ) ⋯ 0 0 0 ⋯ 0 ⋮ ⋮ ⋱ 0 0 0 ⋯ 1 deg ⁡ ( n ) ] ∗ A ~ ∗ [ 1 deg ⁡ ( 1 ) 0 ⋯ 0 0 1 deg ⁡ ( 2 ) ⋯ 0 0 0 ⋯ 0 ⋮ ⋮ ⋱ 0 0 0 ⋯ 1 deg ⁡ ( n ) ] [ a 11 1 deg ⁡ ( 1 ) ∗ 1 deg ⁡ ( 1 ) a 12 1 deg ⁡ ( 1 ) 1 deg ⁡ ( 2 ) ⋯ a 1 n 1 deg ⁡ ( 1 ) ∗ 1 deg ⁡ ( n ) a 21 1 deg ⁡ ( 2 ) ∗ 1 deg ⁡ ( 1 ) a 22 1 deg ⁡ ( 2 ) 1 deg ⁡ ( 2 ) ⋯ a 2 n 1 deg ⁡ ( 2 ) ∗ 1 deg ⁡ ( n ) ⋮ ⋮ ⋱ ⋮ a n 1 1 deg ⁡ ( n ) ∗ 1 deg ⁡ ( 1 ) a n 2 1 deg ⁡ ( n ) 1 deg ⁡ ( 2 ) ⋯ a n n 1 deg ⁡ ( n ) ∗ 1 deg ⁡ ( n ) ] \widetilde{D}^{-\frac{1}{2}}\widetilde{A} \widetilde{D}^{-\frac{1}{2}}=\begin{bmatrix} \frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\ 0 & \frac{1}{\deg{(2)}} &\cdots &0\\ 0 & 0 & \cdots& 0\\ \vdots & \vdots &\ddots &0 \\ 0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}} \end{bmatrix}*\widetilde{A}*\begin{bmatrix} \frac{1}{\sqrt{\deg{(1)}}} & 0 & \cdots & 0\\ 0 & \frac{1}{\deg{(2)}} &\cdots &0\\ 0 & 0 & \cdots& 0\\ \vdots & \vdots &\ddots &0 \\ 0 & 0 & \cdots & \frac{1}{\sqrt{\deg{(n)}}} \end{bmatrix} \\ \begin{bmatrix}a_{11}\frac{1}{\sqrt{\deg{(1)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{12}\frac{1}{\sqrt{\deg{(1)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{1n}\frac{1}{\sqrt{\deg{(1)}}}*\frac{1}{\sqrt{\deg{(n)}}} \\a_{21}\frac{1}{\sqrt{\deg{(2)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{22}\frac{1}{\sqrt{\deg{(2)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{2n}\frac{1}{\sqrt{\deg{(2)}}}*\frac{1}{\sqrt{\deg{(n)}}}\\ \vdots & \vdots &\ddots &\vdots \\a_{n1}\frac{1}{\sqrt{\deg{(n)}}}*\frac{1}{\sqrt{\deg{(1)}}} & a_{n2}\frac{1}{\sqrt{\deg{(n)}}}\frac{1}{\sqrt{\deg{(2)}}} & \cdots & a_{nn}\frac{1}{\sqrt{\deg{(n)}}}*\frac{1}{\sqrt{\deg{(n)}}} \end{bmatrix} D 21A D 21=deg(1) 10000deg(2)1000000deg(n) 1A deg(1) 10000deg(2)1000000deg(n) 1a11deg(1) 1deg(1) 1a21deg(2) 1deg(1) 1an1deg(n) 1deg(1) 1a12deg(1) 1deg(2) 1a22deg(2) 1deg(2) 1an2deg(n) 1deg(2) 1a1ndeg(1) 1deg(n) 1a2ndeg(2) 1deg(n) 1anndeg(n) 1deg(n) 1
​ 大家看到这里应该比较清楚了,这里的系数 a i j a_{ij} aij只能为0或者1,并且取决于 j j j是否为 i i i的邻居节点。 X Θ X\Theta XΘ仅仅只是将特征 X X X进行特征映射了,仅仅只是改变X的列维度。那么,接下来 M X MX MX的值就意味着矩阵M使用系数去选取X的值,这样就实现了选取邻居节点特征的意义。(害,其实我刚看到公式(2)的时候,内心是崩溃的,它怎么就实现抽取邻居节点特征的卷积效果,推导到公式(3)时才恍然大悟,原来是这么回事)。

​ 那么MX的每一行值,就是公式三的结果了。我们再次看下公式(3):
x i k = ∑ j ∈ N ( i ) ∪ ( i ) 1 deg ⁡ ( i ) deg ( j ) x j k − 1 Θ     ( 3 ) x_i^{k}=\sum_{j\in N(i)\cup{(i)}}\frac{1}{\sqrt{\deg(i)}\sqrt{\text{deg}(j)}}x_j^{k-1}\Theta \ \ \ (3) xik=jN(i)(i)deg(i) deg(j) 1xjk1Θ   (3)
​ 分析出message function、reduce function为:

  • message function:每个源节点特征乘以其度的正则
  • reduce function:将message function产生的信息求和,并且乘以目标节点的度正则。
def gcn_msg(edge):
    # 在边上的源节点上,乘以其度的正则
    msg = edge.src['h'] * edge.src['norm']
    return {'m': msg} 
def gcn_reduce(node):
    # 将目标节点的边上的信息聚合(这里是sum),再乘以目标节点上的度的正则
    # 这里的torch.sum(, dim=1)在维度1上相加,是因为node.mailbox['m']的shape = [batch, mails, feat]
    # 需要将mails整个加起来
    accum = torch.sum(node.mailbox['m'], 1) * node.data['norm'] 
    return {'h': accum} # 这时候,节点就存在一个数据node.data['h'] = accum
class NodeApplyModule(nn.Module):
    def __init__(self, out_feats, activation=None, bias=True):
        super(NodeApplyModule, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_feats))
        else:
            self.bias = None
        self.activation = activation
        self.reset_parameters()

    def reset_parameters(self):
        if self.bias is not None:
            stdv = 1. / math.sqrt(self.bias.size(0))
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, nodes): # 要更新时,添加上偏置和激活函数
        h = nodes.data['h']
        if self.bias is not None:
            h = h + self.bias
        if self.activation:
            h = self.activation(h)
        return {'h': h} # 此时nodes['h'] = h 这是就被更新了
class GCNLayer(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 out_feats,
                 activation,
                 dropout,
                 bias=True):
        super(GCNLayer, self).__init__()
        self.g = g
        self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0.
        self.node_update = NodeApplyModule(out_feats, activation, bias)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, h):
        if self.dropout:
            h = self.dropout(h)
        self.g.ndata['h'] = torch.mm(h, self.weight) # 这里是首先将节点特征进行映射,也就是X*O
        self.g.update_all(gcn_msg, gcn_reduce, self.node_update) # 然后求message,聚合,更新。
        h = self.g.ndata.pop('h')
        return h
# add self loop
g = dgl.remove_self_loop(g)
g = dgl.add_self_loop(g)
n_edges = g.number_of_edges()

# normalization
degs = g.in_degrees().float()
norm = torch.pow(degs, -0.5)
norm[torch.isinf(norm)] = 0
g.ndata['norm'] = norm.unsqueeze(1) # 传递到GCN的图已经存在了'norm'数据

你可能感兴趣的:(图神经网络,神经网络)