在PyG搭建GCN前的准备:了解PyG中的数据格式中讲解了PyG中的数据格式,DGL是与PyG齐名的另一大图神经网络框架,二者各有优缺点,建议都学习并掌握。
本篇文章使用Citeseer网络。Citeseer网络是一个引文网络,节点为论文,一共3327篇论文。论文一共分为六类:Agents、AI(人工智能)、DB(数据库)、IR(信息检索)、ML(机器语言)和HCI。如果两篇论文间存在引用关系,那么它们之间就存在链接关系。
DGL加载Citeseer网络:
import dgl
from dgl.data.citation_graph import CiteseerGraphDataset
dataset = CiteseerGraphDataset()
print(len(dataset))
输出为1,说明只有一个网络,然后我们输出一下这个网络:
graph = dataset[0]
print(type(graph))
print(graph)
输出:
<class 'dgl.heterograph.DGLHeteroGraph'>
Graph(num_nodes=3327, num_edges=9228,
ndata_schemes={'train_mask': Scheme(shape=(), dtype=torch.bool), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'feat': Scheme(shape=(3703,), dtype=torch.float32)}
edata_schemes={'__orig__': Scheme(shape=(), dtype=torch.int64)})
在DGL中,所有图都为dgl.DGLGraph格式,为了创建图,可以有以下两种方法:dgl.graph()和dgl.heterograph(),这两个方法分别创建同质图和异质图,二者返回的都是dgl.DGLGraph。
因此,下面先了解一下dgl.DGLGraph。
dgl.DGLGraph类有以下属性和方法:
首先是属性:
DGLGraph.ntypes
返回图中所有类型节点的名称,如上面的网络返回:
['_N']
表明只有一种类型的节点,即论文节点。
同理还有边类型:
DGLGraph.etypes
输出为:
['_E']
同样只有一种类型的边。此外,DGLGraph.srctypes和DGLGraph.dsttypes分别返回源节点和目标节点的类型。
print(graph.metagraph()) # 返回异质图的元图
MultiDiGraph with 1 nodes and 1 edges
print(graph.num_nodes()) # 返回节点数
3327
print(graph.num_edges()) # 返回边数
9228
DGLGraph.nodes()返回节点集合:
print(graph.nodes())
tensor([ 0, 1, 2, ..., 3324, 3325, 3326])
边集合:
print(graph.edges())
(tensor([ 2, 3, 0, ..., 3323, 3326, 3325]), tensor([ 0, 0, 0, ..., 3324, 3325, 3326]))
同样是两个列表,分别对应两端节点编号。
ndata返回节点上的一些信息:
print(graph.ndata)
{'train_mask': tensor([False, True, False, ..., False, False, True]), 'label': tensor([1, 4, 1, ..., 5, 3, 3]), 'val_mask': tensor([False, False, False, ..., False, False, False]), 'test_mask': tensor([False, False, False, ..., False, False, False]), 'feat': tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]])}
feat表示节点特征。
同理有edata,这里不再详细讲解。
利用dgl.graph()方法构建同质图:
dgl.graph(data, ntype=None, etype=None, *, num_nodes=None, idtype=None, device=None, row_sorted=False , col_sorted=False, **deprecated_kwargs )
其中data的形式为(U, V),表示边的两边节点集合;num_nodes表示节点数,如果没有给出,将使用data中的最大id+1,这可能会引发错误:
src_ids = torch.tensor([2, 3, 4])
# Destination nodes for edges (2, 1), (3, 2), (4, 3)
dst_ids = torch.tensor([1, 2, 3])
g = dgl.graph((src_ids, dst_ids))
print(g.num_nodes())
返回的节点数为5,如果给定num_nodes<=4,将引发错误。
dgl.heterograph( data_dict , num_nodes_dict=None , idtype=None , device=None )
具体例子:
data_dict = {
('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
('user', 'follows', 'topic'): (torch.tensor([1, 1]), torch.tensor([1, 2])),
('user', 'plays', 'game'): (torch.tensor([0, 3]), torch.tensor([3, 4]))
}
g = dgl.heterograph(data_dict)
print(g)
上图中,一共有user、topic和game三种类型的节点,他们有三种类型的关系,右边的数据表示边两边节点的索引。
此外,可以显式指定节点个数:
num_nodes_dict = {'user': 4, 'topic': 4, 'game': 6}
g = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict)
这里一样指定数目不能小于边集合中的最小索引+1。