在PyG中,除了直接使用它自带的benchmark数据集外,用户还可以自定义数据集,其方式与Pytorch类似,需要继承数据集类。PyG中提供了两个数据集抽象类:
torch_geometric.data.Dataset
:用于构建大型数据集(非内存数据集);torch_geometric.data.InMemoryDataset
:用于构建内存数据集(小数据集),继承自Dataset
。下面是对其的详细介绍。
在PyG中要构建自己的内存数据集需要先继承InMemoryDataset
类,并实现如下方法:
raw_file_names()
:返回原始数据集的文件名列表,若self.raw_dir
中没有该列表中的文件,则会通过download()
进行下载;processed_file_names()
:返回process()
方法处理后的文件名列表,若self.processed_dir
中没有确实该列表中的文件,则需要通过process()
方法进行处理;download()
:下载原始数据集到self.raw_dir
中;process()
:处理原始数据集,并保存到processed_dir
中。在前两个方法中,若只有单个文件,则直接返回文件字符串即可,不一定要返回list对象。
另外,上面的self.raw_dir
和self.processed_dir
其实是两个方法,其源码为:
# 加上@property,可以使得方法像属性一样被调用
@property
def raw_dir(self) -> str:
return osp.join(self.root, 'raw')
@property
def processed_dir(self) -> str:
return osp.join(self.root, 'processed')
从源码可以看出,self.raw_dir
和self.processed_dir
是给定保存路径root
下的原始数据文件夹和处理后的数据文件夹的路径。
本文以SNAP数据集中的一个社交网络Facebook为例,来演示如何创建一个InMemoryDataset
数据集FaceBook
,该数据集包含4039个节点、88234条边。利用Gephi对该网络进行可视化如下:
根据3.1节中的说明,下面是自定义FaceBook
类的源码:
import os
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset, download_url, extract_gz
class FaceBook(InMemoryDataset):
url = "https://snap.stanford.edu/data/facebook_combined.txt.gz"
def __init__(self,
root,
transform=None,
pre_transform=None,
pre_filter=None):
super().__init__(root, transform, pre_transform, pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ["facebook_combined.txt"]
@property
def processed_file_names(self):
return "data.pt"
def download(self):
path = download_url(self.url, self.raw_dir)
extract_gz(path, self.raw_dir)
def process(self):
# 加载原始数据文件
path = os.path.join(self.raw_dir, "facebook_combined.txt")
edges = pd.read_csv(path, header=None,
delimiter=" ").values.reshape(2, -1)
# 构建Data对象
edge_index = torch.from_numpy(edges)
g = Data(edge_index=edge_index, num_nodes=4039)
data, slices = self.collate([g])
torch.save((data, slices), self.processed_paths[0])
if __name__ == "__main__":
dataset = FaceBook(root="tmp")
data = dataset[0]
print(data.num_edges, data.num_nodes)
# 88234 4039
需要注意的是
download
和process
只在第一次调用时会调用,之后会直接加载处理好的数据集。download()
函数来下载原始数据集。对于大型图数据集,需要继承Dataset
类,除了InMemoryDataset
中需要重写的4个方法外,还需重写如下方法:
len()
: 返回数据集中实例的数量;get()
:加载单个图的逻辑。由于自定义大型数据集与InMemoryDataset
类似,具体演示略。
参考资料:
自定义数据集是一项重要的事情,尤其是当你本地有些数据需要转换为PyG中标准的图数据集的时候。