DGL在 dgl.data 里实现了很多常用的图数据集。它们遵循了由 dgl.data.DGLDataset 类定义的标准的数据处理管道。DGL推荐用户将图数据处理为 dgl.data.DGLDataset 的子类。该类为导入、处理和保存图数据提供了简单而清晰的解决方案。
通过调用“has_cache()”判断磁盘上是否有已经处理好的数据集缓存。如果有,则跳转到第5步,直接加载数据集;
调用“download()”下载数据;
调用“process()”处理数据;
调用“save()”保存处理好的数据到磁盘,跳转到第6步;
调用“load()”从磁盘加载数据集;
完成。
下面给出了一个继承自DGLDataset类的例子。子类中必须实现process(), getitem(idx) 和 len()。同时官方建议也实现save()和load(),避免对大型数据集的重复处理。
from dgl.data import DGLDataset
class MyDataset(DGLDataset):
""" 用于在DGL中自定义图数据集的模板:
Parameters
----------
url : str
下载原始数据集的url。
raw_dir : str
指定下载数据的存储目录或已下载数据的存储目录。默认: ~/.dgl/
save_dir : str
处理完成的数据集的保存目录。默认:raw_dir指定的值
force_reload : bool
是否重新导入数据集。默认:False
verbose : bool
是否打印进度信息。
"""
def __init__(self,
url=None,
raw_dir=None,
save_dir=None,
force_reload=False,
verbose=False):
super(MyDataset, self).__init__(name='dataset_name',
url=url,
raw_dir=raw_dir,
save_dir=save_dir,
force_reload=force_reload,
verbose=verbose)
def download(self):
# 将原始数据下载到本地磁盘
pass
def process(self):
# 将原始数据处理为图、标签和数据集划分的掩码
pass
def __getitem__(self, idx):
# 通过idx得到与之对应的一个样本
pass
def __len__(self):
# 数据样本的数量
pass
def save(self):
# 将处理后的数据保存至 `self.save_path`
pass
def load(self):
# 从 `self.save_path` 导入处理后的数据
pass
def has_cache(self):
# 检查在 `self.save_path` 中是否存有处理后的数据
pass
这一段就是给实现“download()”举了两个例子。
从“self.url”链接下载到“self.raw_dir”目录下,保存为“self.name+格式后缀”:
import os
from dgl.data.utils import download
def download(self):
# 存储文件的路径
file_path = os.path.join(self.raw_dir, self.name + '.mat')
# 下载文件
download(self.url, path=file_path)
如果数据集是一个zip文件,可以直接继承 dgl.data.DGLBuiltinDataset 类,其支持解压缩zip文件。
如果文件是.gz、.tar、.tar.gz或.tgz文件,下载后需要用 extract_archive() 函数进行解压缩:
from dgl.data.utils import download, check_sha1
def download(self):
# 存储文件的路径,请确保使用与原始文件名相同的后缀
gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')
# 下载文件
download(self.url, path=gz_file_path)
# 检查 SHA-1
if not check_sha1(gz_file_path, self._sha1_str):
raise UserWarning('File {} is downloaded but the content hash does not match.'
'The repo may be outdated or download may be incomplete. '
'Otherwise you can create an issue for it.'.format(self.name + '.csv.gz'))
# 将文件解压缩到目录self.raw_dir下的self.name目录中
self._extract_gz(gz_file_path, self.raw_path)
假设数据已经下载到“self.raw_dir”目录下,接下来就可以处理数据了。根据图上的任务,分别从整图分类、节点分类和链接预测介绍。
整图分类任务与传统机器学习任务类似,整图为特征,类别为标签。调用“process()”将数据集处理为 dgl.DGLGraph 对象的列表和标签张量的列表。
class QM7bDataset(DGLDataset):
_url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
'datasets/qm7b.mat'
_sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
super(QM7bDataset, self).__init__(name='qm7b',
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
mat_path = self.raw_path + '.mat'
self.graphs, self.label = self._load_graph(mat_path)
def _load_graph(self, filename):
data = io.loadmat(filename)
labels = F.tensor(data['T'], dtype=F.data_type_dict['float32'])
feats = data['X']
num_graphs = labels.shape[0]
graphs = []
for i in range(num_graphs):
edge_list = feats[i].nonzero()
g = dgl_graph(edge_list)
g.edata['h'] = F.tensor(feats[i][edge_list[0], edge_list[1]].reshape(-1, 1),
dtype=F.data_type_dict['float32'])
graphs.append(g)
return graphs, labels
def save(self):
"""save the graph list and the labels"""
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
save_graphs(str(graph_path), self.graphs, {'labels': self.label})
def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph.bin')
return os.path.exists(graph_path)
def load(self):
graphs, label_dict = load_graphs(os.path.join(self.save_path, 'dgl_graph.bin'))
self.graphs = graphs
self.label = label_dict['labels']
def download(self):
file_path = os.path.join(self.raw_dir, self.name + '.mat')
download(self.url, path=file_path)
if not check_sha1(file_path, self._sha1_str):
raise UserWarning('File {} is downloaded but the content hash does not match.'
'The repo may be outdated or download may be incomplete. '
'Otherwise you can create an issue for it.'.format(self.name))
@property
def num_labels(self):
return 14
def __getitem__(self, idx):
return self.graphs[idx], self.label[idx]
def __len__(self):
return len(self.graphs)
处理完数据后,就可以跟传统分类任务一样使用数据了。
import dgl
import torch
from torch.utils.data import DataLoader
# 数据导入
dataset = QM7bDataset()
num_labels = dataset.num_labels
# 创建collate_fn函数
def _collate_fn(batch):
graphs, labels = batch
g = dgl.batch(graphs)
labels = torch.tensor(labels, dtype=torch.long)
return g, labels
# 创建 dataloaders
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=_collate_fn)
# 训练
for epoch in range(100):
for g, labels in dataloader:
# 用户自己的训练代码
pass
与整图分类不同,节点分类通常在单个图上进行。因此数据集的划分是在图的节点集上进行。 DGL建议使用节点掩码来指定数据集的划分,相当于给节点做一个标记,明确是为训练节点(“g.ndata[‘train_mask’]”)、验证节点(“g.ndata[‘val_mask’]”)还是测试节点(“g.ndata[‘test_mask’]”)。 本节以内置数据集 CitationGraphDataset 为例,支持’cora’, ‘citeseer’, 'pubmed’三个常用的数据集,DGL已经分别针对三个数据集构建了子类CoraGraphDataset、CiteseerGraphDataset和PubmedGraphDataset。
from dgl.data import DGLBuiltinDataset
from dgl.data.utils import _get_dgl_url, generate_mask_tensor
class CitationGraphDataset(DGLBuiltinDataset):
_urls = {
'cora_v2' : 'dataset/cora_v2.zip',
'citeseer' : 'dataset/citeseer.zip',
'pubmed' : 'dataset/pubmed.zip',
}
def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
assert name.lower() in ['cora', 'citeseer', 'pubmed']
if name.lower() == 'cora':
name = 'cora_v2'
url = _get_dgl_url(self._urls[name])
super(CitationGraphDataset, self).__init__(name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
# 跳过一些处理的代码
# === 跳过数据处理 ===
# 构建图
g = dgl.graph(graph)
# 划分掩码
g.ndata['train_mask'] = generate_mask_tensor(train_mask)
g.ndata['val_mask'] = generate_mask_tensor(val_mask)
g.ndata['test_mask'] = generate_mask_tensor(test_mask)
# 节点的标签
g.ndata['label'] = torch.tensor(labels)
# 节点的特征
g.ndata['feat'] = torch.tensor(_preprocess_features(features),
dtype=F.data_type_dict['float32'])
self._num_labels = onehot_labels.shape[1]
self._labels = labels
self._g = g
def __getitem__(self, idx):
assert idx == 0, "这个数据集里只有一个图"
return self._g
def __len__(self):
return 1
由于数据集只有一个图,所以需要取第0个元素“dataset[0]”:
# 创建链接预测数据集示例
class KnowledgeGraphDataset(DGLBuiltinDataset):
def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):
self._name = name
self.reverse = reverse
url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)
super(KnowledgeGraphDataset, self).__init__(name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
# 跳过一些处理的代码
# === 跳过数据处理 ===
# 划分掩码
g.edata['train_mask'] = train_mask
g.edata['val_mask'] = val_mask
g.edata['test_mask'] = test_mask
# 边类型
g.edata['etype'] = etype
# 节点类型
g.ndata['ntype'] = ntype
self._g = g
def __getitem__(self, idx):
assert idx == 0, "这个数据集只有一个图"
return self._g
def __len__(self):
return 1
下面利用’FB15k-237’对应的子类 dgl.data.FB15k237Dataset 来做演示如何使用用于链路预测的数据集:
from dgl.data import FB15k237Dataset
# 导入数据
dataset = FB15k237Dataset()
graph = dataset[0]
# 获取训练集掩码
train_mask = graph.edata['train_mask']
train_idx = torch.nonzero(train_mask).squeeze()
src, dst = graph.edges(train_idx)
# 获取训练集中的边类型
rel = graph.edata['etype'][train_idx]
DGL提供了4个函数:
dgl.save_graphs(): 保存DGLGraph对象和标签到本地磁盘
dgl.load_graphs():从本地磁盘读取它们
dgl.data.utils.save_info(): 将数据集的有用信息(python dict对象)保存到本地磁盘
dgl.data.utils.load_info()和从本地磁盘读取它们
import os
from dgl import save_graphs, load_graphs
from dgl.data.utils import makedirs, save_info, load_info
def save(self):
# 保存图和标签
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
save_graphs(graph_path, self.graphs, {'labels': self.labels})
# 在Python字典里保存其他信息
info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
save_info(info_path, {'num_classes': self.num_classes})
def load(self):
# 从目录 `self.save_path` 里读取处理过的数据
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
self.graphs, label_dict = load_graphs(graph_path)
self.labels = label_dict['labels']
info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
self.num_classes = load_info(info_path)['num_classes']
def has_cache(self):
# 检查在 `self.save_path` 里是否有处理过的数据文件
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
return os.path.exists(graph_path) and os.path.exists(info_path)
当处理过的数据比较大时,在 getitem(idx) 中处理每个数据实例是更高效的方法。
OGB(Open Graph Benchmark)是一个图深度学习的基准数据集。 官方的 ogb 包提供了用于下载和处理OGB数据集到 dgl.data.DGLGraph 对象的API。
首先需要使用“pip install ogb”安装这个包,接着就可以根据任务从里面加载数据集了。
类的命名十分统一,只需要执行“dataset = DglGraphPropPredDataset(name=‘ogbg-molhiv’)”即可得到相应的数据集,然后与传统机器学习任务类似,将数据处理为(graph, label)的形式。
# 载入OGB的Graph Property Prediction数据集
import dgl
import torch
from ogb.graphproppred import DglGraphPropPredDataset
from torch.utils.data import DataLoader
def _collate_fn(batch):
# 小批次是一个元组(graph, label)列表
graphs = [e[0] for e in batch]
g = dgl.batch(graphs)
labels = [e[1] for e in batch]
labels = torch.stack(labels, 0)
return g, labels
# 载入数据集
dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
split_idx = dataset.get_idx_split()
# dataloader
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
类似地,执行“dataset = DglNodePropPredDataset(name=‘ogbn-proteins’)”即可获取数据集,这种数据集只有一个图对象。
# 载入OGB的Node Property Prediction数据集
from ogb.nodeproppred import DglNodePropPredDataset
dataset = DglNodePropPredDataset(name='ogbn-proteins')
split_idx = dataset.get_idx_split()
# there is only one graph in Node Property Prediction datasets
# 在Node Property Prediction数据集里只有一个图
g, labels = dataset[0]
# 获取划分的标签
train_label = dataset.labels[split_idx['train']]
valid_label = dataset.labels[split_idx['valid']]
test_label = dataset.labels[split_idx['test']]
通过执行“dataset = DglLinkPropPredDataset(name=‘ogbl-ppa’)”获取数据集,同样是单图。
# 载入OGB的Link Property Prediction数据集
from ogb.linkproppred import DglLinkPropPredDataset
dataset = DglLinkPropPredDataset(name='ogbl-ppa')
split_edge = dataset.get_edge_split()
graph = dataset[0]
print(split_edge['train'].keys())
print(split_edge['valid'].keys())
print(split_edge['test'].keys())