使用自己的数据集,通过PyG封装的库来转变为Pytorch的数据集
虽然 PyG 已经包含许多有用的数据集,但您可能希望使用自我记录或非公开可用的数据创建自己的数据集。
自己实现数据集很简单,可能只需要查看源代码以了解各种数据集是如何实现的。下面将简要介绍设置您自己的数据集所需的内容。
PyG为数据集提供了两个抽象类:torch_geometric.data.Dataset 和torch_geometric.data.InMemoryDataset。 InMemoryDataset 继承自 Dataset,如果整个数据集储存在CPU,则应该使用它。
按照 torchvision 约定,每个数据集都存在一个根文件夹,该文件夹指示数据集的存储位置。
我们将根文件夹分成两个文件夹:raw_dir,数据集下载到的位置,以及处理后的数据集保存的位置。
另外,每个数据集都可以传递一个transform、一个pre_transform和一个pre_filter函数,它们默认为None。
为了创建一个 torch_geometric.data.InMemoryDataset,需要实现四个基本方法:
process()函数是真正起到主体作用的函数。在这里,我们需要读取并创建一个 Data 对象列表并将其保存到 processes_dir 中。因为保存一个巨大的 python 列表相当慢,我们在保存之前通过 torch_geometric.data.InMemoryDataset.collate() 将列表整理成一个巨大的 Data 对象。整理后的数据对象将所有示例连接到一个大数据对象中,此外,还返回一个切片字典以从该对象重构单个示例。最后,我们需要在构造函数中将这两个对象加载到 self.data 和 self.slices 属性中。
代码如下(示例):
import torch
from torch_geometric.data import InMemoryDataset, download_url
class MyOwnDataset(InMemoryDataset):## 继承torch_geometric.data.InMemoryDataset父类
def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
super().__init__(root, transform, pre_transform, pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0]) ##在构造函数中将这两个对象加载到 self.data 和 self.slices 属性中。
@property
def raw_file_names(self): ##如果raw_file中有文件就会跳过下载
return ['some_file_1', 'some_file_2', ...]
@property ##如果processed_file就会跳过处理
def processed_file_names(self):
return ['data.pt']
def download(self):
# 下载到`self.raw_dir`。
download_url(url, self.raw_dir)
...
def process(self):
# 将数据读入巨大的“数据”列表。
data_list = [...]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
如果你的内存比较小无法创建内存数据集,可以使用 torch_geometric.data.Dataset,它紧跟 torchvision 数据集的概念。它另外实现以下方法:
在内部,torch_geometric.data.Dataset.getitem() 从 torch_geometric.data.Dataset.get() 获取数据对象,并可选择根据变换对其进行变换。
import os.path as osp
import torch
from torch_geometric.data import Dataset, download_url
class MyOwnDataset(Dataset):
def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
super().__init__(root, transform, pre_transform, pre_filter)
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data_1.pt', 'data_2.pt', ...]
def download(self):
path = download_url(url, self.raw_dir)
...
def process(self):
idx = 0
for raw_path in self.raw_paths:
data = Data(...)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
idx += 1
def len(self):
return len(self.processed_file_names)
def get(self, idx):
data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
return data
在这里,每个图形数据对象都单独保存在 process() 中,并在 get() 中手动加载。
你可以通过不重写download()和process()方法来跳过下载和/或处理。
## 比如,对比上述代码
class MyOwnDataset(Dataset):
def __init__(self, transform=None, pre_transform=None):
super().__init__(None, transform, pre_transform)
不!就像在常规 PyTorch 中一样,您不必使用数据集,例如,当您想要动态创建合成数据而不将它们显式保存到磁盘时。在这种情况下,只需传递一个包含 torch_geometric.data.Data 对象的常规 python 列表并将它们传递给 torch_geometric.loader.DataLoader:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)
class MyDataset(InMemoryDataset):
def __init__(self, root, data_list, transform=None):
self.data_list = data_list
super().__init__(root, transform)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def processed_file_names(self):
return 'data.pt'
def process(self):
torch.save(self.collate(self.data_list), self.processed_paths[0])
1.上述代码中self.processed_paths[0]输出的是什么?
2.collate() 有什么作用?(将torch_geometric.data.Data 对象的 Python 列表整理为 InMemoryDataset 的内部存储格式。)
一般还是手动加载数据集,因为内存有限数据集可能非常大。