详解DGL中的消息传递API

目录

  • 1. 前言
  • 2. DGL消息传递范式
    • 2.1 消息函数
    • 2.2 聚合函数
    • 2.3 更新函数
    • 2.4 update_all()
  • 案例:GCN

1. 前言

最近有用到DGL来手写GNN算子,因此需要简单理解一下消息传递范式。

2. DGL消息传递范式

对于一个图 G G G来说,节点 v v v的特征 x v x_v xv,边 ( u , v ) (u,v) (u,v)上的特征为 w e w_e we,那么DGL定义了如下消息传递范式:

  1. 消息函数
    m e ( t + 1 ) = ϕ ( x v ( t ) , x u ( t ) , w e ( t ) ) , ( u , v , e ) ∈ E . m_{e}^{(t+1)} = \phi \left( x_v^{(t)}, x_u^{(t)}, w_{e}^{(t)} \right) , ({u}, {v},{e}) \in \mathcal{E}. me(t+1)=ϕ(xv(t),xu(t),we(t)),(u,v,e)E.
    函数 ϕ \phi ϕ为定义在每条边上的消息函数,它将边的特征与两个节点的特征相结合来生成边上的消息。
  2. 点计算
    x v ( t + 1 ) = ψ ( x v ( t ) , ρ ( { m e ( t + 1 ) : ( u , v , e ) ∈ E } ) ) . x_v^{(t+1)} = \psi \left(x_v^{(t)}, \rho\left(\left\lbrace m_{e}^{(t+1)} : ({u}, {v},{e}) \in \mathcal{E} \right\rbrace \right) \right). xv(t+1)=ψ(xv(t),ρ({me(t+1):(u,v,e)E})).
    聚合函数 ρ \rho ρ会聚合节点接受到的消息,更新函数 ψ \psi ψ会结合聚合后的消息和节点本身的特征来更新节点。

2.1 消息函数

在DGL中,消息函数接受一个edges(类型为dgl.EdgeBatch)参数,该参数由src、dst和data三个属性组成,对应 ( u , v , e ) (u, v, e) (u,v,e)

利用消息函数,我们可以将两个节点特征相结合然后赋给边,也可以将节点和边特征相结合然后赋给边,即:node+node->edge或node+edge->edge。

内置消息函数可以是一元函数或二元函数。对于一元函数,DGL支持 copy 函数。对于二元函数,DGL现在支持 addsubmuldivdot 函数。消息的内置函数的命名约定是 u 表示 节点,v 表示 目标 节点,e 表示 这些函数的参数是字符串,指示相应节点和边的输入和输出特征字段名

例如,要对源节点的 hu 特征和目标节点的 hv 特征求和,然后将结果保存在边的 he 特征上,用户可以使用内置函数 dgl.function.u_add_v('hu', 'hv', 'he')。而以下用户定义消息函数与此内置函数等价:

def message_func(edges):
     return {'he': edges.src['hu'] + edges.dst['hv']}

此外,在DGL中也可以在不涉及消息传递的情况下,通过apply_edges()单独调用逐边计算:

import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))

2.2 聚合函数

DGL内置的聚合函数包括sum、max、min以及mean。聚合函数通常有两个参数,它们的类型都是字符串。一个用于指定mailbox 中的字段名,一个用于指示目标节点特征的字段名。其中,NodeBatch.mailbox是用来暂存消息函数发来的数据。

例如dgl.function.sum('m', 'h')等价于:

import torch
def reduce_func(nodes):
     return {'h': torch.sum(nodes.mailbox['m'], dim=1)}

即将节点收到的消息中的m特征求和,然后暂存在节点的mailbox中,用h指示。最后该h特征和节点自身的某个特征结合用以更新节点。

2.3 更新函数

更新函数的作用就是将经过聚合后的节点的mailbox中的数据和节点本身特征相结合用以更新节点。

2.4 update_all()

update_all()是DGL内置的一个函数,可以同时实现消息生成、消息聚合以及更新,它的参数由三个:一个消息函数、一个聚合函数和一个更新函数。 更新函数是一个可选择的参数,用户也可以不使用它,而是在update_al()执行完后直接对节点特征进行操作。例如:

def update_all_example(graph):
    # 在graph.ndata['ft']中存储结果
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
    # 在update_all外调用更新函数
    final_ft = graph.ndata['ft'] * 2
    return final_ft

上面代码中,消息函数的作用是将源节点的ft特征与目标节点的a特征相乘,然后存放到边的m特征上,实际上是将该结果存放到mailbox[‘m’]中。然后,聚合函数将节点接收到的所有mailbox的m特征,也就是mailbox[‘m’]求和,放到mailbox[‘ft’]里面。最后,对源节点的ft特征进行更新,也就是聚合结果*2得到更新后的特征。

案例:GCN

利用DGL中的消息传递范式来手搭一个GCN用于节点分类,请见下一篇文章。

你可能感兴趣的:(DGL,dgl,GNN,图神经网络)