虽然Pytorch-Geometric提供了很多官方数据集,但是当需要构建自己的数据集的时候,就需要对如何使用dataset
基类构造自己的数据集有所了解。库中提供了两个构建数据集的基类:torch_geometric.data.Dataset
和torch_geometric.data.InMemoryDataset
,其中torch_geometric.data.InMemoryDataset
继承了torch_geometric.data.Dataset
,表示是否将整个数据集加载到内存中。
根据torchvision
的习惯,每一个数据集都需要指定一个根目录,根目录下面需要分为两个文件夹,一个是raw_dir
,这个表示下载的原始数据的存放位置,另一个是processed_dir
,表示处理后的数据集存放位置。
另外,每一个数据集函数都可以传递函数transform
,pre_transform
和pre_filter
,默认为None
。transform
函数用于数据对象被加载使用之前进行的动态转换(一般用于数据增强
);pre_transform
函数将数据对象保存到磁盘以前进行的转换,也就是得到processed_dir
内数据文件之前对其调用(一般用于只需要计算一次的复杂预处理过程);pre_filter
函数在数据进行保存之前进行过滤。
构建torch_geometric.data.InMemoryDataset
,需要重写(区分重载和重写)四个函数:
(1)torch_geometric.data.InMemoryDataset.raw_file_names()
存放raw_dir
目录下所有数据文件名的字符串列表,用于下载时的检查过程(正如之前的文章提到的,数据集下载的时候会检测是否已经存在,避免重复下载,也就是如何避免自动下载的httperror
的解决方案)。
(2)torch_geometric.data.InMemoryDataset.processed_file_names()
和(1)类似,存放processed_dir
目录下的文件名的列表,用于检测是否已经存在(不会二次处理)。
(3)torch_geometric.data.InMemoryDataset.download()
下载数据到raw_dir
目录下。
(4)torch_geometric.data.InMemoryDataset.process()
对raw_dir
下的数据进行处理并存储到processed_dir
目录下。
因此,可以发现关键在于第四个函数的实现,函数内首先需要读取原始数据并创建一个torch_geometric.data.Data
对象的列表,并存储到processed_dir
目录下面。直接存储和使用这个python-list
时间代价很高,所以在存储之前调用torch_geometric.data.InMemoryDataset.collate()
函数将列表转换为一个torch_geometric.data.Data
对象。处理后的数据被整合到了一个数据对象中(作为返回值),同时返回一个slices
字典来获取到这个数据对象中单个数据,所以总结下来process
过程一共分四步:
最后在数据类的构造函数中加载数据集并赋值给self.data
和self.slices
。
import torch
from torch_geometric.data import InMemoryDataset
class MyDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
# 数据的下载和处理过程在父类中调用实现
super(MyDataset, self).__init__(root, transform, pre_transform)
# 加载数据
self.data, self.slices = torch.load(self.processed_paths[0])
# 将函数修饰为类属性
@property
def raw_file_names(self):
return ['file_1', 'file_2']
@property
def processed_file_names(self):
return ['data.pt']
def download(self):
# download to self.raw_dir
pass
def process(self):
data_list = [...]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_filter is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
# 这里的save方式以及路径需要对应构造函数中的load操作
torch.save((data, slices), self.processed_paths[0])
大数据集一般不会直接加载到内存中,这里构建数据集的时候需要继承父类torch_geometric.data.Dataset
。在上面构建数据集时,重写了四个函数,此处还需要多实现两个函数:
(1)torch_geometric.data.Dataset.len()
返回数据集的文件个数。
(2)torch_geometric.data.Dataset.get()
实现对单个数据(图数据集的话一般是单个图)的加载逻辑。
import os.path as osp
import torch
# 这里就不能用InMemoryDataset了
from torch_geometric.data import Dataset
class MyDataset(Dataset):
# 默认预处理函数的参数都是None
def __init__(self, root, transform=None, pre_transform=None):
super(MyDataset, self).__init__(root, transform, pre_transform)
@property
def raw_file_names(self):
return ['file_1', 'file_2']
@property
def processed_file_names(self):
# 一次无法加载所有数据,所以对数据进行了分解
return ['data1.pt', 'data2.pt', 'data3.pt']
def download(self):
# Download to raw_dir
pass
def process(self):
i = 0
# 遍历每一个文件路径
for raw_path in self.raw_paths:
data = Data(...)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
i += 1
def len(self):
return len(self.processed_file_names)
def get(self, idx):
data = torch.load(osp.join(self.processed_dir, 'data{}.pt',format(idx)))
return data
当我第一遍看完文档之后,心中还是存在很多疑惑的,第一,毕竟直接继承了一个父类,具体的流程是如何的,还不清楚,第二,没有亲自制作一个数据集,的确理解上存在模糊,下面对我个人的一些疑惑进行探索。
这里的流程是指包括了个人定义的数据类内部的逻辑以及父类InMemoryDataset
中的逻辑(先分析内存数据集):
1.对MyDataset
实例化,此时调用类内构造函数__init__
,先通过父类构造函数,再从本地加载数据,因此所有的关键操作都是在父类构造中发生的。
2.(在调用父类构造函数的时候,根据文档的官方例子我产生了两个疑惑,第一个是参数中没有传递pre_filter
参数,但是后面为什么还要判断self.pre_filter
,难道说默认的pre_filter
不是None
?而是父类中给了一个实现方式?第二个是参数中传递了transform
,但是在重写的process
函数并没有transform
的过程,那么这个过程又是在哪里实现的呢?)在InMemoryDataset
类中,构造函数为:
def __init__(self, root=None, transform=None, pre_transform=None,
pre_filter=None):
super(InMemoryDataset, self).__init__(root, transform, pre_transform,
pre_filter)
self.data, self.slices = None, None
其中transform
、pre_transform
和pre_filter
都是函数句柄(callable),具体说明如下:
(1)transform
接受参数类型为torch_geometric.data.Data
并且返回一个转换后的版本(数据类型不变),在每一次数据加载到程序之前都会默认调用进行数据转换。
(2)pre_transform
接收参数类型为torch_geometric.data.Data
,返回转换后的版本,在数据被存储到硬盘之前进行转换(只发生一次)。
(3)pre_filter
接受参数类型为torch_geometric.data.Data
,返回布尔类型结果,相当于对原始数据的一个mask
。
可以看到InMemoryDataset
中构造函数的参数,这三个函数参数都是None
。这也就是解决了之前的第一个疑问,如果要用pre_filter
,就必须传递该参数,否则为None
。
3.调用InMemoryDataset
的父类Dataset
的构造函数,其实此处就可以发现大部分的逻辑已经可以在Dataset
类中看到了。先对之前的疑惑二进行解答何时调用transform
,为什么在process
中没有transform
呢?
def __getitem__(self, idx):
r"""Gets the data object at index :obj:`idx` and transforms it (in case
a :obj:`self.transform` is given).
In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
tuple, a LongTensor or a BoolTensor, will return a subset of the
dataset at the specified indices."""
if isinstance(idx, int):
data = self.get(self.indices()[idx])
data = data if self.transform is None else self.transform(data)
return data
else:
return self.index_select(idx)
这一段代码是源码Dataset
类中的函数,可以看到这个函数是根据索引获取部分数据,idx
为索引目标,可以是列表、元组、LongTensor或者BoolTensor。可以看到只有在访问数据元素时,才会调用transform
函数。
4.在Dataset的构造函数中,有这么几行代码:
if 'download' in self.__class__.__dict__.keys():
self._download()
if 'process' in self.__class__.__dict__.keys():
self._process()
此处调用下载函数和处理函数,而self._download()
会调用self.download()
,process
同理。
5.将处理好的数据存储到本地,然后再加载到程序中。
以上就是详细的处理流程了,值得注意的是,如果需要下载数据,利用request相关技术,需要自己重写download()
函数;如果要对数据进行预过滤、转换和预转换,需要定义外部函数作为参数传递给构造过程。
看了上面的内容,可能还是不知道咋做,现在就通过官方数据集的源码进行一波分析。例子以Planetoid
为例:
from torch_geometric.datasets import Planetoid
1.构造函数中transform
和pre_transform
都设置了None
,但是没有pre_filter
参数,也就是说这里不允许传递pre_filter
参数。
def __init__(self, root, name, transform=None, pre_transform=None):
self.name = name
super(Planetoid, self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])
该数据集只有一个数据文件,所以直接取索引0。
2. 下载函数如下:
def download(self):
for name in self.raw_file_names:
download_url('{}/{}'.format(self.url, name), self.raw_dir)
遍历每一个文件名,然后调用download_url
函数进行下载。
from torch_geometric.data import download_url
不过在download_url
和Dataset
类中的_download
函数中都进行防覆盖检测。
3.处理函数如下:
def process(self):
data = read_planetoid_data(self.raw_dir, self.name)
data = data if self.pre_transform is None else self.pre_transform(data)
torch.save(self.collate([data]), self.processed_paths[0])
第一步读取数据,第二步转换,第三步存储,主要是第一步的操作,这里调用了一个函数read_planetoid_data
,此函数读取本地文件后,进行了训练集、测试集、验证集的划分,并且构造了一个Data对象:
data = Data(x=x, edge_index=edge_index, y=y)
data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask
在存储之前调用了
self.collate([data])
该函数的具体内容在下一小节中讲解。
collate
函数在InMemoryDataset
中实现,将一个python列表形式数据转换(每一个元素都是一个数据对象)为torch_geometric.data.InMemoryDataset
内部存储数据的格式。这里每一个数据对象未必是Data类型(一般代表一个Graph),也可以是其他的,比如图片等。
data = data_list[0].__class__()
这一行代码可以对第一个元素的类名解析并重新构造一个同类型元素。
for item, key in product(data_list, keys):
data[key].append(item[key])
利用笛卡尔积构造元组替代双层循环,并且将列表中所有数据元素的值存放到一个数据对象中。后面的代码进行了一些拼接过程,具体的见Github。