即使PyG已经包含了很多数据集,但是如果大家想使用自己的或非公开数据集,还是需要实现自己的 dataset
。对于数据集的创建涉及到两个类, torch_geometric.data.Dataset
和 torch_geometric.data.InMemoryDataset
,其中第二个是第一个的子类,如果希望全部数据都在内存里则需要使用第二个类。每个数据集需要提供文件夹路径作为参数,其中一个 raw_dir
存储数据集的源文件,而另一个参数 processed_dir
存储处理过的文件。
每个数据集都会经过 transform
,pre_transform
,pre_filter
三个函数,默认是 None
。第一个函数在使用前动态的转化数据对象(所以最好用于数据增强);第二个函数是将数据集存储在磁盘前的转换函数(最好用于仅需做一次的大量预计算任务);最后一个函数在存储前过滤一些对象。
为了创建这个数据集,需要实现下面四个基本方法:
raw_file_names
:raw_dir
文件列表,如果源文件在这里存在的话,就可以跳过下载。
processed_file_names
:在 processed_dir
里的文件列表,用于跳过处理。
download
:将源文件下载到 raw_dir
里面。
process
:处理源数据并保存到 processed_dir
里。在这里面,需要读取并创建一个 Data
对象列表存储到上面的文件夹里,但是python存储是慢的,因此我们在存储前通过 collate
将list合为一个大的 Data
对象,然后从这个对象返回一个 slices
字典用于重构单个样例。最后我们需要加载两个对象 self.data
, self.slices
。
对于其他更高级的方法参考torch_geometric.data
下面是一个简单的例子。
import torch
from torch_geometric.data import InMemoryDataset, download_url
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super().__init__(root, transform, pre_transform)
# 初始化时将数据读入内存
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data.pt']
def download(self):
# 将源文件下载到`self.raw_dir`.
download_url(url, self.raw_dir)
...
def process(self):
# 读数据到大的 `Data` 列表.
data_list = [...]
if self.pre_filter is not None: # 首先过滤
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
# 将处理后的数据存储
torch.save((data, slices), self.processed_paths[0])
对于更大的不能放在内存中的数据就需要用到 torch_geometric.data.Dataset
,他还需要实现以下方法:
len
:返回数据集中的样本数。
get
:实现读取一个图的逻辑。
还有 __getitem__()
方法从 get()
中获取一个数据对象,并根据 transform
选择性地转化他们。
下面是一个简单的例子:
import os.path as osp
import torch
from torch_geometric.data import Dataset, download_url
class MyOwnDataset(Dataset):
def __init__(self, root, transform=None, pre_transform=None):
super().__init__(root, transform, pre_transform)
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data_1.pt', 'data_2.pt', ...]
def download(self):
# Download to `self.raw_dir`.
path = download_url(url, self.raw_dir)
...
'''----------------------前面都是一样的---------------------'''
def process(self):
# 这个函数是因为数据比较多,无法一次读入内存,所以以图为单位分开读取、处理、再存储
idx = 0
for raw_path in self.raw_paths:
# 从 `raw_path`读取数据.
data = Data(...)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
idx += 1
def len(self):
return len(self.processed_file_names)
def get(self, idx):
data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
return data
如何跳过 download
或 process
的执行?
不对这两个方法进行重载就好了~
class MyOwnDataset(Dataset):
def __init__(self, transform=None, pre_transform=None):
super().__init__(None, transform, pre_transform)
我真的需要使用这些数据集的接口吗?
不需要!仅仅是将Data合并为一个list,将他们传进 DataLoader
即可。
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)
CREATING YOUR OWN DATASETS