了解DGL中的数据格式

目录

  • 前言
  • 1. DGL数据集
    • 1.1 dgl.DGLGraph
    • 1.2 dgl.graph()
    • 1.3 dgl.heterograph()

前言

在PyG搭建GCN前的准备:了解PyG中的数据格式中讲解了PyG中的数据格式,DGL是与PyG齐名的另一大图神经网络框架,二者各有优缺点,建议都学习并掌握。

1. DGL数据集

本篇文章使用Citeseer网络。Citeseer网络是一个引文网络,节点为论文,一共3327篇论文。论文一共分为六类:Agents、AI(人工智能)、DB(数据库)、IR(信息检索)、ML(机器语言)和HCI。如果两篇论文间存在引用关系,那么它们之间就存在链接关系。

DGL加载Citeseer网络:

import dgl
from dgl.data.citation_graph import CiteseerGraphDataset

dataset = CiteseerGraphDataset()
print(len(dataset))

输出为1,说明只有一个网络,然后我们输出一下这个网络:

graph = dataset[0]
print(type(graph))
print(graph)

输出:

<class 'dgl.heterograph.DGLHeteroGraph'>
Graph(num_nodes=3327, num_edges=9228,
      ndata_schemes={'train_mask': Scheme(shape=(), dtype=torch.bool), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'feat': Scheme(shape=(3703,), dtype=torch.float32)}
      edata_schemes={'__orig__': Scheme(shape=(), dtype=torch.int64)})

在DGL中,所有图都为dgl.DGLGraph格式,为了创建图,可以有以下两种方法:dgl.graph()和dgl.heterograph(),这两个方法分别创建同质图和异质图,二者返回的都是dgl.DGLGraph。

因此,下面先了解一下dgl.DGLGraph。

1.1 dgl.DGLGraph

dgl.DGLGraph类有以下属性和方法:

首先是属性:

DGLGraph.ntypes

返回图中所有类型节点的名称,如上面的网络返回:

['_N']

表明只有一种类型的节点,即论文节点。

同理还有边类型:

DGLGraph.etypes

输出为:

['_E']

同样只有一种类型的边。此外,DGLGraph.srctypes和DGLGraph.dsttypes分别返回源节点和目标节点的类型。

print(graph.metagraph())   # 返回异质图的元图
MultiDiGraph with 1 nodes and 1 edges
print(graph.num_nodes())   # 返回节点数
3327
print(graph.num_edges())   # 返回边数
9228

DGLGraph.nodes()返回节点集合:

print(graph.nodes())
tensor([   0,    1,    2,  ..., 3324, 3325, 3326])

边集合:

print(graph.edges())
(tensor([   2,    3,    0,  ..., 3323, 3326, 3325]), tensor([   0,    0,    0,  ..., 3324, 3325, 3326]))

同样是两个列表,分别对应两端节点编号。

ndata返回节点上的一些信息:

print(graph.ndata)
{'train_mask': tensor([False,  True, False,  ..., False, False,  True]), 'label': tensor([1, 4, 1,  ..., 5, 3, 3]), 'val_mask': tensor([False, False, False,  ..., False, False, False]), 'test_mask': tensor([False, False, False,  ..., False, False, False]), 'feat': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])}

feat表示节点特征。

同理有edata,这里不再详细讲解。

1.2 dgl.graph()

利用dgl.graph()方法构建同质图:

dgl.graph(data, ntype=None, etype=None, *, num_nodes=None, idtype=None, device=None, row_sorted=False , col_sorted=False, **deprecated_kwargs )

其中data的形式为(U, V),表示边的两边节点集合;num_nodes表示节点数,如果没有给出,将使用data中的最大id+1,这可能会引发错误:

src_ids = torch.tensor([2, 3, 4])
# Destination nodes for edges (2, 1), (3, 2), (4, 3)
dst_ids = torch.tensor([1, 2, 3])
g = dgl.graph((src_ids, dst_ids))
print(g.num_nodes())

返回的节点数为5,如果给定num_nodes<=4,将引发错误。

1.3 dgl.heterograph()

dgl.heterograph( data_dict , num_nodes_dict=None , idtype=None , device=None )

具体例子:

data_dict = {
    ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
    ('user', 'follows', 'topic'): (torch.tensor([1, 1]), torch.tensor([1, 2])),
    ('user', 'plays', 'game'): (torch.tensor([0, 3]), torch.tensor([3, 4]))
}
g = dgl.heterograph(data_dict)
print(g)

上图中,一共有user、topic和game三种类型的节点,他们有三种类型的关系,右边的数据表示边两边节点的索引。

此外,可以显式指定节点个数:

num_nodes_dict = {'user': 4, 'topic': 4, 'game': 6}
g = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict)

这里一样指定数目不能小于边集合中的最小索引+1。

你可能感兴趣的:(DGL,dgl,pytorch)