DGL学习笔记-1

笔记是直接从Jupyter 保存下来的,格式会比较乱,主要是给自己看的。

1. 构建图

导入依赖库

import dgl
import torch
import numpy as np

1.1 构建图

src_idx = np.random.randint(0,3,5)  # sorce nodes id
dst_idx = np.random.randint(0,3,5)  # destination nodes id

G = dgl.graph((src_idx, dst_idx))   # construct a graph

# print number of nodes
print("number of nodes: ", G.num_nodes())
# print number of edges
print("number of edges: ", G.num_edges())

增加图的节点特征:

# add node features
node_feat = torch.rand(G.num_nodes(),4)
G.ndata['x'] = node_feat

查看图的边数据和节点数据:

G.edata, G.ndata

结果如下:

({},
 {'x': tensor([[0.6212, 0.5459, 0.1122, 0.6694],
         [0.2265, 0.1132, 0.1532, 0.3810],
         [0.4304, 0.3933, 0.6462, 0.9284]])})

2. apply_edges 方法

该方法主要是用于 对每条边进行计算。

# 定义一个消息传递函数,函数指定一个 edges 对象,该对象具有:src、dst、data 三个属性。
# src: 这条边的sorce节点,src['x'] 表示sorce节点的x特征
# dst:这条边的destination节点
# data:这条边的属性,data['e']表示变得e特征
def fn_es(edges):
    return {'e': edges.src['x'] + edges.dst['x']}
G.apply_edges(fn_es)

G

输出结果:

Graph(num_nodes=3, num_edges=5,
      ndata_schemes={'x': Scheme(shape=(4,), dtype=torch.float32)}
      edata_schemes={'e': Scheme(shape=(4,), dtype=torch.float32)})

3. apply_nodes 方法

该方法是用于 对每个节点进行计算。

# 定义一个聚合函数,函数指定一个 nodes 对象,该对象具有:mailbox、data 属性。
# mailbox:存储了通过消息传递到节点的信息,这里暂时用不到
# data:保持了该节点的特征
def fn_ns(nodes):
    print(nodes.mailbox)  # 查看一下mailbox
    return {'h': nodes.data['x'] + 1}

G.apply_nodes(fn_ns)

G.ndata

输出结果如下:

None

{'x': tensor([[0.6212, 0.5459, 0.1122, 0.6694],
        [0.2265, 0.1132, 0.1532, 0.3810],
        [0.4304, 0.3933, 0.6462, 0.9284]]), 'h': tensor([[1.6212, 1.5459, 1.1122, 1.6694],
        [1.2265, 1.1132, 1.1532, 1.3810],
        [1.4304, 1.3933, 1.6462, 1.9284]])}

官方文档的例子(apply_nodes可以用lambda定义函数;可以用v参数,只对特定节点进行计算)。

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
g.ndata['h'] = torch.ones(5, 2)
g.apply_nodes(lambda nodes: {'x' : nodes.data['h'] * 2},
             v=(1,2)  # 只对id为1、2的两个节点进行计算
             )
g.ndata['x']

结果如下:

tensor([[0., 0.],
        [2., 2.],
        [2., 2.],
        [0., 0.],
        [0., 0.]])

4. 消息传递

消息函数 接受一个参数 edges,这是一个 EdgeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批边。 edges 有 src、 dst 和 data 共3个成员属性, 分别用于访问源节点、目标节点和边的特征。

4.1 DGL内置的消息传递函数

内置消息函数可以是一元函数或二元函数。对于一元函数,DGL支持 copy 函数。对于二元函数, DGL现在支持 add、 sub、 mul、 div、 dot 函数。消息的内置函数的命名约定是 u 表示 源 节点, v 表示 目标 节点,e 表示 边。

# hu 指的是source节点的 hu 属性特征。
# hv 指的是destination节点的 hv 属性特征。
# he 指的是将 hu + hv 结果保存于 边 上,并命名为 he 特征。

dgl.function.u_add_v('hu', 'hv', 'he')

4.2 自定义消息传递函数

# hu 指的是source节点的 hu 属性特征。
# hv 指的是destination节点的 hv 属性特征。
# he 指的是将 hu + hv 结果保存于 边 上,并命名为 he 特征。

def message_func(edges):
     return {'he': edges.src['hu'] + edges.dst['hv']}

5. 消息聚合

聚合函数 接受一个参数 nodes,这是一个 NodeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批节点。 nodes 的成员属性 mailbox 可以用来访问节点收到的消息。 一些最常见的聚合操作包括 sum、max、min 等。

5.1 DGL内置的聚合函数

DGL支持内置的聚合函数 sum、 max、 min 和 mean 操作。 聚合函数通常有两个参数,它们的类型都是字符串。一个用于指定 mailbox 中的字段名,一个用于指示目标节点特征的字段名, 例如:

# m 指的是保存在 nodes 的 mailbox 属性的字段名
# h 指的是输出的字段名

dgl.function.sum('m', 'h')

5.2 自定义消息传递函数

# m 指的是保存在 nodes 的 mailbox 属性的字段名
# h 指的是输出的字段名

import torch
def reduce_func(nodes):
     return {'h': torch.sum(nodes.mailbox['m'], dim=1)}

6. 更新函数

更新函数 接受一个如上所述的参数 nodes。此函数对 聚合函数 的聚合结果进行操作, 通常在消息传递的最后一步将其与节点的特征相结合,并将输出作为节点的新特征。

7. update_all

高级API,在该API中,实现了消息传递、消息聚合、节点更新

update_all() 的参数是一个消息函数、一个聚合函数和一个更新函数。 更新函数是一个可选择的参数,用户也可以不使用它,而是在 update_all 执行完后直接对节点特征进行操作。 由于更新函数通常可以用纯张量操作实现,所以DGL不推荐在 update_all 中指定更新函数。

# 此调用通过将源节点特征 ft 与边特征 a 相乘生成消息 m, 然后对所有消息求和来更新节点特征 ft,再将 ft 乘以2得到最终结果 final_ft。
# 调用后,中间消息 m 将被清除。

import dgl.function as fn
def update_all_example(graph):
    # 在graph.ndata['ft']中存储结果
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
    # 在update_all外调用更新函数
    final_ft = graph.ndata['ft'] * 2
    return final_ft

使用案例如下:

# 在前文定义的G图上,增加节点特征,以便于应用update_all_example函数。
G.ndata['ft'] = G.ndata['x']
G.edata['a'] = G.edata['e']

update_all_example(G)

输出结果如下:

tensor([[2.4490, 1.9310, 1.0305, 4.7589],
        [1.1309, 0.7969, 2.0665, 4.8625],
        [1.3067, 1.0255, 0.1701, 2.1390]])

你可能感兴趣的:(深度学习,学习,python,pytorch)