最近有用到DGL来手写GNN算子,因此需要简单理解一下消息传递范式。
对于一个图 G G G来说,节点 v v v的特征 x v x_v xv,边 ( u , v ) (u,v) (u,v)上的特征为 w e w_e we,那么DGL定义了如下消息传递范式:
在DGL中,消息函数接受一个edges(类型为dgl.EdgeBatch)参数,该参数由src、dst和data三个属性组成,对应 ( u , v , e ) (u, v, e) (u,v,e)。
利用消息函数,我们可以将两个节点特征相结合然后赋给边,也可以将节点和边特征相结合然后赋给边,即:node+node->edge或node+edge->edge。
内置消息函数可以是一元函数或二元函数。对于一元函数,DGL支持 copy
函数。对于二元函数,DGL现在支持 add
、 sub
、 mul
、 div
、 dot
函数。消息的内置函数的命名约定是 u
表示 源
节点,v
表示 目标
节点,e
表示 边
。这些函数的参数是字符串,指示相应节点和边的输入和输出特征字段名。
例如,要对源节点的 hu
特征和目标节点的 hv
特征求和,然后将结果保存在边的 he
特征上,用户可以使用内置函数 dgl.function.u_add_v('hu', 'hv', 'he')
。而以下用户定义消息函数与此内置函数等价:
def message_func(edges):
return {'he': edges.src['hu'] + edges.dst['hv']}
此外,在DGL中也可以在不涉及消息传递的情况下,通过apply_edges()单独调用逐边计算:
import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
DGL内置的聚合函数包括sum、max、min以及mean。聚合函数通常有两个参数,它们的类型都是字符串。一个用于指定mailbox
中的字段名,一个用于指示目标节点特征的字段名。其中,NodeBatch.mailbox是用来暂存消息函数发来的数据。
例如dgl.function.sum('m', 'h')
等价于:
import torch
def reduce_func(nodes):
return {'h': torch.sum(nodes.mailbox['m'], dim=1)}
即将节点收到的消息中的m特征求和,然后暂存在节点的mailbox中,用h指示。最后该h特征和节点自身的某个特征结合用以更新节点。
更新函数的作用就是将经过聚合后的节点的mailbox中的数据和节点本身特征相结合用以更新节点。
update_all()
是DGL内置的一个函数,可以同时实现消息生成、消息聚合以及更新,它的参数由三个:一个消息函数、一个聚合函数和一个更新函数。 更新函数是一个可选择的参数,用户也可以不使用它,而是在update_al()
执行完后直接对节点特征进行操作。例如:
def update_all_example(graph):
# 在graph.ndata['ft']中存储结果
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft'))
# 在update_all外调用更新函数
final_ft = graph.ndata['ft'] * 2
return final_ft
上面代码中,消息函数的作用是将源节点的ft特征与目标节点的a特征相乘,然后存放到边的m特征上,实际上是将该结果存放到mailbox[‘m’]中。然后,聚合函数将节点接收到的所有mailbox的m特征,也就是mailbox[‘m’]求和,放到mailbox[‘ft’]里面。最后,对源节点的ft特征进行更新,也就是聚合结果*2得到更新后的特征。
利用DGL中的消息传递范式来手搭一个GCN用于节点分类,请见下一篇文章。