图神经网络框架DGL中的 消息函数、聚合函数及更新函数 的理解与说明

在DGL框架中,当我们明白了边、顶点、图以及边和顶点的属性后,接下来需要了解的概念就是 三座大山了:

  • 消息函数
  • 聚合函数
  • 更新函数

mailbox

dgl的节点有一个概念是mailbox, 用一暂存消息函数发送过来的数据。 消息-> 邮箱, 这样的概念是比较容易理解的。

 

消息函数

默认的消息函数是 ϕ,它接受的参数是edges ,类型是dgl.EdgeBatch.   edges有src,dst和data三个属性,分别是源顶点、目标顶点和边,可以用这三个属性访问各自的特征。

表示:  node + node -> mailbox  或  Node + edge -> mailbox

 

内置消息函数: 

一元: copy 函数

二元: add, sub, mul, div, dot 函数

约定: 名字上的u表示源节点,v表示目标节点,e表示边。

参数: 字符串参数, 表示相应节点的输入和输出特征名字段名

例: 对源节点的hu特征和目标节点的hv特征求和,然后将结果保存在边的he特征上

dgl.function.u_add_v('hu','hv','he')

如果要自定义此消息函数,等价于以下代表,注意返回的是dict格式的数据。

def message_func(edges):
	 return {'he': edges.src['hu'] + edges.dst['hv']}		

 

聚合函数

 

默认的聚合函数是ρ , 接受的参数类型是nodes ,也就是顶点集合,类型为 dgl.NodeBatch, nodes有成员属性mailbox, 用来访问节点收到的消息。 mailbox可以理解为一块临时存贮区,在消息函数运行后用来暂存数据。 此时并不会更新目标节点数据。

内置的聚合函数:

sum, max,min,mean 操作

参数:这些函数通常都是两个参数,类型为字符串

一个用于指定mailbox中的字段名

一个用于指定目标节点特征的字段名

如dgl.function.sum('m','h')等价于如下所示的自定义函数。 注意,聚合只是聚合,并不更新任何值 ,只执行聚合的任务,说白了就是把消息函数中传来的数据进行处理但不更新,切记。

import torch
def reduce_func(nodes):
	 return {'h': torch.sum(nodes.mailbox['m'], dim=1)}

 

更新函数

 

前两步接收和聚合后的数据,需要更新目标节点的特征,参数为nodes, 类型为dgl.NodeBatch.  此函数对聚合函数的聚合结果进行操作,在消息传递的最后一步将其与其它节点的特征组合后,作为节点的新特征。

前面讲了消息和mailbox的概念,更新函数的作用就是按需将mailbox中的数据搬回家(与节点的数据合并)

apply_edeges()

在不涉及消息传递时,可以调用apply_edges()函数进行逐边计算

参数为一个消息函数,默认为更新所有的边

例子:

import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))	

update_all()

该接口合并了消息生成,消息聚合,节点特重更新,好处是可以给这三步操作作一个整体优化,用更底层的高效算法,比如直接调用cuda函数进行操作,从而提高运行效率。

参数为: 一个消息函数,一个聚合函数,一个更新函数。 官方文档不建议在这儿使用更新函数,可以自己在随后进行操作。

示例:

def updata_all_example(graph):
	# store the result in graph.ndata['ft']
	graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
					 fn.sum('m', 'ft'))
	# Call update function outside of update_all
	final_ft = graph.ndata['ft'] * 2
	return final_ft

这段代码 将节点 特征字段 ft与边特征字段a相乘后生成消息m (存放于暂存位置mailbox中),然后对所有的消息求和来更新节点特征ft, 再将ft乘以2得到最终结果 final_ft, 调用后mailbox中的中间结果m会被清除。公式表示为:

 

你可能感兴趣的:(图神经网络,图神经网络DGL)