使用Pytorch Geometric建立自己的数据集
一.引言
最近在学习GNN的相关知识,在编写代码的过程中不乏会遇到编程问题,学习过GNN的各位应该都知道这个叫PyG(PyTorch Geometric, PYG)的库[1]。这个库是专门用于编写和训练图神经网络(GNN)的库,能够让使用者快速方便地实现GNN网络。然而在训练和使用网络之前,训练数据或预测数据的生成是十分必要的,没有数据的话也无谈训练,因此本文就建立个性化小数据集上为给予大家一个构建的思路。(基于节点级的分类任务)
因为最近需要用到PYG库构建自己的数据集(包含生成.pt文件),因此便上网搜索搭建的方法,发现网上无论是CSDN还是知乎写的相关文章都有点含糊不清,翻看官网文档官网只给了一段代码例子确没有细讲具体如何实现。因此,我只能经过自己爬坑从而搭建了用于图神经网络训练的数据集,希望这篇介绍对大家在学习GNN的道路上有所帮助。废话不多说,直接开始。
二.使用PYG创建自己的小数据集
在搭建自己的数据集上,官网给出了两个例子——1.构建小数据集(内存数据集) 2.构建大数据集(非内存数据集)。小数据集是可以直接读入内存的,而大数据集是需要分批次读入内存(可以选择一次读入的数量)。对于初学者来说一般使用小批量的数据集即可,因此接下来就介绍小批量的数据集构建。首先官网文档给出了一段简化的例子:
import torch
from torch_geometric.data import InMemoryDataset, download_url
#这里给出大家注释方便理解
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super().__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])
#返回数据集源文件名
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
#返回process方法所需的保存文件名。你之后保存的数据集名字和列表里的一致
@property
def processed_file_names(self):
return ['data.pt']
#用于从网上下载数据集
def download(self):
# Download to `self.raw_dir`.
download_url(url, self.raw_dir)
...
#生成数据集所用的方法
def process(self):
# Read data into huge `Data` list.
data_list = [...]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
首先,想建立自己的数据集最好通过官网的方式即创建一个类来构建数据集,因为这种方法会便捷快速。首先,由于我们不需要下载其他的数据集而是自己从本地构建数据集,因此这里不需要覆盖download的方法把download注释掉。(如果不注释在初始的时候回自动先执行download方法,然后执行processed_file_names方法返回本地的.pt文件并重构torch_geometric.data)。如下:
# #用于从网上下载数据集---这里注释掉
# def download(self):
# # Download to `self.raw_dir`.
# download_url(url, self.raw_dir)
...
#生成数据集所用的方法
def process(self):
# Read data into huge `Data` list.
data_list = [...]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
其次,为了填充data_list里的内容,需要先生成data对象。什么是data对象呢?
在PyG中,单个graph定义为torch_geometric.data.Data实例,默认有以下属性:
,shape为[num_nodes, num_dimensions]。
这些参数都不是必须的,而且,Data类也不仅仅限制于这些参数。
例子[2]:对于一个不带权重的无向图
,有三个节点和四条边,每一个节点的特征维度为1,如下图
注意:这里一条无向边使用两条边(节点对)表示,比如对于节点0和1,无向边表示为(0,1)和(1,0)。
import torch
from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
>>> Data(edge_index=[2, 4], x=[3, 1])
对于edge_index参数,shape为(2,4),4表示边的条数,对于x参数,shape为(3,1),3表示节点数,1表示节点特征维度。
那么,如果我们想建立一个对于节点级任务的数据集呢?就需要加入标签可以是如下的。
import torch
from torch_geometric.data import Data
#[2,num_edges]从上指向下一一对应。
Edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
#每个节点的特征:从0号节点开始。。
X = torch.tensor([[-1], [0], [1]], dtype=torch.float)
#每个节点的标签:从0号节点开始-两类0,1
Y = torch.tensor([0,1,0],dtye=torch.float)
data = Data(x=x, edge_index=edge_index,y=Y)
这里设置了0号节点标签为0,1号节点标签为1,2号节点标签为0.然后将该生成的data放入data_list即可,
# 生成数据集所用的方法
def process(self):
# Read data into huge `Data` list.
#这里用于构建data
Edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
# 每个节点的特征:从0号节点开始。。
X = torch.tensor([[-1], [0], [1]], dtype=torch.float)
# 每个节点的标签:从0号节点开始-两类0,1
Y = torch.tensor([[0, 0], [1, 1], [2, 0]],dtye=torch.float)
data = Data(x=x, edge_index=edge_index, y=Y)
#放入datalist
data_list = [data]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
最后,我们加入自定义一下.pt文件名并实例化对象,给出数据集所需存储的文件根目录,测试一下就生成了我们想要的数据集。
import torch
from torch_geometric.data import InMemoryDataset, download_url
#这里给出大家注释方便理解
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super().__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])
#返回数据集源文件名
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
#返回process方法所需的保存文件名。你之后保存的数据集名字和列表里的一致
@property
def processed_file_names(self):
return ['data.pt']
# #用于从网上下载数据集
# def download(self):
# # Download to `self.raw_dir`.
# download_url(url, self.raw_dir)
...
#生成数据集所用的方法
def process(self):
# Read data into huge `Data` list.
# Read data into huge `Data` list.
# 这里用于构建data
Edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
# 每个节点的特征:从0号节点开始。。
X = torch.tensor([[-1], [0], [1]], dtype=torch.float)
# 每个节点的标签:从0号节点开始-两类0,1
Y = torch.tensor([0, 1, 0],dtye=torch.float)
data = Data(x=x, edge_index=edge_index, y=Y)
# 放入datalist
data_list = [data]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
可以看到他会自动在文件根目录下生成三个文件。
最后,需要使用data时直接使用示例化后的对象即可。
"""测试"""
b = MyOwnDataset("MYdata")
>>>Process
b.data.num_features
>>>1
b.data.num_nodes
>>>3
b.data.num_edges
>>>4
三.结尾
希望这篇文章能够帮助到GNN方面的初学者。如果你觉得不错可以收藏~欢迎交流学习。
参考
[1] PyG Documentation — pytorch_geometric 2.0.2 documentation
[2] https://zhuanlan.zhihu.com/p/78452993