# DGL: creating, saving and loading graphs (DGL_图的创建、保存、加载)

import dgl
import torch as th
from dgl.data.utils import save_graphs

import os

# Graph 1: 3 nodes and 6 directed edges (0->0, 0->1, 0->2, 1->1, 1->2, 2->2).
g1 = dgl.DGLGraph()
g1.add_nodes(3)
g1.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
# One 5-dim feature row per node / per edge. The row counts are read from the
# graph so they stay in sync if the node or edge lists above are edited.
g1.ndata["x"] = th.ones(g1.number_of_nodes(), 5)   # node features: all ones
g1.edata['y'] = th.zeros(g1.number_of_edges(), 5)  # edge features: all zeros
# Other call forms accepted by add_edges (illustration only: they use node IDs
# above 2, which this 3-node graph does not have, so running them here would
# change or be rejected by this graph depending on the DGL version):
# g1.add_edges(th.tensor([3, 4, 5]), 1)  # three edges: 3->1, 4->1, 5->1
# g1.add_edges(4, [7, 8, 9])  # three edges: 4->7, 4->8, 4->9
# g1.add_edges([1, 2, 3], [3, 4, 5])  # three edges: 1->3, 2->4, 3->5

# Graph 2: 3 nodes and 3 directed edges, with 4-dim edge features.
g2 = dgl.DGLGraph()
g2.add_nodes(3)
g2.add_edges([0, 1, 2], [1, 2, 1])
g2.edata["e"] = th.ones(g2.number_of_edges(), 4)

# Graph-level labels saved alongside the graphs: one value per saved graph,
# here each graph's node count (computed rather than hard-coded).
graph_labels = {
    "graph_sizes": th.tensor([g.number_of_nodes() for g in (g1, g2)])
}

out_path = "data/try1.bin"
# Create the output directory up front so the write does not fail on a fresh
# checkout where "data/" does not exist yet.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
save_graphs(out_path, [g1, g2], graph_labels)
from dgl.data.utils import load_graphs, load_labels

# Loading every graph stored in the file would be:
#   glist, label_dict = load_graphs("data/try1.bin")  # glist will be [g1, g2]
# Passing an index list loads only the selected graphs: [0] -> [g1].
# As the expected output below shows, the graph-label dict is still the full
# one (labels for every saved graph), not just those of the selected graphs.
glist, label_dict = load_graphs("data/try1.bin", [0])
# load_labels reads only the label dict, without loading any graphs.
# Despite its name, graph_sizes holds the whole dict, not just the tensor.
graph_sizes = load_labels("data/try1.bin")

print(glist)
# Expected output (the exact repr format varies between DGL versions):
# [DGLGraph(num_nodes=3, num_edges=6,
#          ndata_schemes={'x': Scheme(shape=(5,), dtype=torch.float32)}
#          edata_schemes={'y': Scheme(shape=(5,), dtype=torch.float32)})]
print(label_dict)
# {'graph_sizes': tensor([3, 3])}
print(graph_sizes)
# {'graph_sizes': tensor([3, 3])}

# Related tags from the original post: DGL, data mining (你可能感兴趣的:(DGL,数据挖掘))