动手学PyG(二)PyG中包含的常用Benchmark数据集

PyG中常用的Benchmark数据集


本文主要参考了 PyG英文文档

PyG中包含了大量常用的benchmark数据集,如所有Planetoid数据集(Cora, Citeseer, Pubmed)。来自http://graphkernels.cs.tu-dortmund.de/的所有图分类数据集,和其简洁版,QM7和QM9数据集等等。

数据集的下载和使用非常简单,下面以ENZYMES数据集(包含600个图,分为6类)为例做介绍。

from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
>>> ENZYMES(600)

len(dataset)
>>> 600

dataset.num_classes
>>> 6

dataset.num_node_features
>>> 3

接下来我们就可以访问每一张图了:

data = dataset[0]
>>> Data(edge_index=[2, 168], x=[37, 3], y[1])

data.is_undirected()
>>> True

我们可以发现数据库中的第一张图包含37个节点,每个节点有3个属性。共有168/2=84个无向连边,该图被划分为一个种类。
我们同样可以用列表的切片操作来划分训练/校验/测试集:

train_dataset = dataset[:540]
>>> ENZYMES(540)

test_dataset = dataset[540:]
>>> ENZYMES(60)

如果你不确定这些图有没有被打乱过,你可以执行以下代码:

dataset = dataset.shuffle()
>>> ENZYMES(600)

接下来让我们试一下Cora数据集,该数据集常用于本监督节点分类任务的测试。

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')
>>> Cora()

len(dataset)
>>> 1

dataset.num_classes
>>> 7

dataset.num_node_features
>>> 1433

Cora数据集中只有一张无向引用图。

data = dataset[0]
>>> Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2709], val_mask=[2708], x=[2708, 1433], y =[2708])

data.is_undirected()
>>>140

data.train_mask.sum().item()
>>> 140

data.val_mask.sum().item()
>>> 500

data.test_mask.sum().item()
>>> 1000

Cora数据集有三个多余的node-level的属性:train_mask, val_mask, test_mask。其中:

  • train_mask 表示那些节点用于训练。
  • val_mask 表示哪些节点用于校验。
  • test_mask 表示哪些节点用于测试。

你可能感兴趣的:(动手学PyG,pytorch,深度学习,python)