看GAT之前,先看看DGL库自带的一些函数,遵循消息传播范式,DGL自带了很多消息函数和传播函数,都在Builtin message passing functions中,总的来说,DGL由两个api组成消息传播:
send(edges, message_func) :for computing the messages along the given edges
recv(nodes, reduce_func) : for collecting the in-coming messages, perform aggregation and so on.
细看这里的两个函数:
DGLGraph.send(edges='__ALL__', message_func='default')
:沿着给定的边发送消息,发送的方式是一个message_func
函数定义,返回边缘上的消息,稍后可以在目标节点(destination node)的邮箱(mailbox
)中获取。接收将消耗消息。
DGLGraph.recv(v='__ALL__', reduce_func='default', apply_node_func='default', inplace=False)
:接收和变换( reduce_func
)传入的消息,并更新节点v
的特征。可选的apply_node_func
函数在接收后更新节点特性。
例子如 https://docs.dgl.ai/generated/dgl.DGLGraph.recv.html#dgl.DGLGraph.recv 所示。
dgl.DGLGraph.update_all(message_func='default', reduce_func='default', apply_node_func='default')
:该函数通过所有边发送消息并更新所有节点。可选地,应用一个函数在接收后更新节点特征。这是上面两个函数的组合:send(self, self.edges(), message_func)
还有recv(self, self.nodes(), reduce_func, apply_node_func)
。
dgl.DGLGraph.apply_nodes(func='default', v='__ALL__', inplace=False)
:参数func
是应用在节点上的函数,用以更新节点特征,如下增加节点特征的例子:
import torch as th
g = dgl.DGLGraph()
g.add_nodes(3)
g.ndata['x'] = th.ones(3,1)
def fun(nodes): return {'x': nodes.datas['x'] + 1}
g.apply_nodes(func=fun,v = [0, 2])
print(g.ndata)
{'x': tensor([[2.],
[1.],
[2.]])}
同理有dgl.DGLGraph.apply_edgess(func='default', edges='__ALL__', inplace=False)
函数。例子在 https://docs.dgl.ai/generated/dgl.DGLGraph.apply_edges.html?highlight=update。
其他函数:
dgl.DGLGraph.set_n_initializer
默认的初始化:
G.set_n_initializer(dgl.init.zero_initializer)
G.ndata.update
:D.update([E, ]**F) -> None. Update D from mapping/iterable E and F.G.edata.pop
:D.pop(k[,d]) -> v, remove specified key and return the corresponding value.:附一个表,参考 Builtin message passing functions
Category | Functions | Memo |
---|---|---|
Unary message function | copy_u |
|
copy_e |
||
copy_src |
alias of copy_u |
|
copy_edge |
alias of copy_e |
|
Binary message function | u_add_v , u_sub_v , u_mul_v , u_div_v |
|
u_add_e , u_sub_e , u_mul_e , u_div_e |
||
v_add_u , v_sub_u , v_mul_u , v_div_u |
||
v_add_e , v_sub_e , v_mul_e , v_div_e |
||
e_add_u , e_sub_u , e_mul_u , e_div_u |
||
e_add_v , e_sub_v , e_mul_v , e_div_v |
||
src_mul_edge |
alias of u_mul_e |
|
Reduce function | max |
|
min |
||
sum |
||
prod |
放别处了: