PyG创建数据集

PyG创建数据集

使用自己的数据集,通过PyG封装的库来转变为Pytorch的数据集


文章目录

  • PyG创建数据集
  • 前言
  • 一、封装的库
  • 二、创建内存数据集
    • 1.一个例子
  • 三、创建较大的数据集
    • 1.例子代码如下:
  • 四、总结
    • 1.常见的问题:
    • 2.尝试思考:


前言

虽然 PyG 已经包含许多有用的数据集,但您可能希望使用自我记录或非公开可用的数据创建自己的数据集。
自己实现数据集很简单,可能只需要查看源代码以了解各种数据集是如何实现的。下面将简要介绍设置您自己的数据集所需的内容。


一、封装的库

PyG为数据集提供了两个抽象类:torch_geometric.data.Dataset 和torch_geometric.data.InMemoryDataset。 InMemoryDataset 继承自 Dataset,如果整个数据集储存在CPU,则应该使用它。

按照 torchvision 约定,每个数据集都存在一个根文件夹,该文件夹指示数据集的存储位置。

我们将根文件夹分成两个文件夹:raw_dir,数据集下载到的位置,以及处理后的数据集保存的位置。

另外,每个数据集都可以传递一个transform、一个pre_transform和一个pre_filter函数,它们默认为None。

  • transform的功能在访问数据之前动态的转换数据对象(用来数据增强)。
  • pre_transform的功能将数据对象保存到磁盘之前应用的转换(因此它最好用于只需要执行一次大量的预计算)。
  • pre_filter的功能可以在保存之前手动过滤掉数据对象。用例可能涉及限制数据对象属于特定类(过滤筛选)。

二、创建内存数据集

为了创建一个 torch_geometric.data.InMemoryDataset,需要实现四个基本方法:

  • torch_geometric.data.InMemoryDataset.raw_file_names(): 为了跳过下载,需要找到 raw_dir 中的文件列表。
  • torch_geometric.data.InMemoryDataset.processed_file_names():为了跳过处理,需要找到process_dir中的文件列表。
  • torch_geometric.data.InMemoryDataset.download():将原始数据下载到 raw_dir。
  • torch_geometric.data.InMemoryDataset.process(): 处理原始数据并将其保存到 processes_dir 中。

process()函数是真正起到主体作用的函数。在这里,我们需要读取并创建一个 Data 对象列表并将其保存到 processes_dir 中。因为保存一个巨大的 python 列表相当慢,我们在保存之前通过 torch_geometric.data.InMemoryDataset.collat​​e() 将列表整理成一个巨大的 Data 对象。整理后的数据对象将所有示例连接到一个大数据对象中,此外,还返回一个切片字典以从该对象重构单个示例。最后,我们需要在构造函数中将这两个对象加载到 self.data 和 self.slices 属性中。

1.一个例子

代码如下(示例):

import torch
from torch_geometric.data import InMemoryDataset, download_url


class MyOwnDataset(InMemoryDataset):## 继承torch_geometric.data.InMemoryDataset父类
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])  ##在构造函数中将这两个对象加载到 self.data 和 self.slices 属性中。

    @property
    def raw_file_names(self): ##如果raw_file中有文件就会跳过下载
        return ['some_file_1', 'some_file_2', ...]

    @property  ##如果processed_file就会跳过处理
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # 下载到`self.raw_dir`。
        download_url(url, self.raw_dir)
        ...

    def process(self):
        # 将数据读入巨大的“数据”列表。
        data_list = [...]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

三、创建较大的数据集

如果你的内存比较小无法创建内存数据集,可以使用 torch_geometric.data.Dataset,它紧跟 torchvision 数据集的概念。它另外实现以下方法:

  • torch_geometric.data.Dataset.len():返回数据集中样本的数量。
  • torch_geometric.data.Dataset.get():实现加载单个图形的逻辑。

在内部,torch_geometric.data.Dataset.getitem() 从 torch_geometric.data.Dataset.get() 获取数据对象,并可选择根据变换对其进行变换。

1.例子代码如下:

import os.path as osp

import torch
from torch_geometric.data import Dataset, download_url


class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        path = download_url(url, self.raw_dir)
        ...

    def process(self):
        idx = 0
        for raw_path in self.raw_paths:
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
            idx += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

在这里,每个图形数据对象都单独保存在 process() 中,并在 get() 中手动加载。


四、总结

1.常见的问题:

  1. 如何跳过 download() 和/或 process() 的执行?

你可以通过不重写download()和process()方法来跳过下载和/或处理。

## 比如,对比上述代码
class MyOwnDataset(Dataset):
    def __init__(self, transform=None, pre_transform=None):
        super().__init__(None, transform, pre_transform)
  1. 我们真的需要使用这些数据集接口吗?

不!就像在常规 PyTorch 中一样,您不必使用数据集,例如,当您想要动态创建合成数据而不将它们显式保存到磁盘时。在这种情况下,只需传递一个包含 torch_geometric.data.Data 对象的常规 python 列表并将它们传递给 torch_geometric.loader.DataLoader:

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)

2.尝试思考:

class MyDataset(InMemoryDataset):
    def __init__(self, root, data_list, transform=None):
        self.data_list = data_list
        super().__init__(root, transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return 'data.pt'

    def process(self):
        torch.save(self.collate(self.data_list), self.processed_paths[0])

1.上述代码中self.processed_paths[0]输出的是什么?
2.collat​​e() 有什么作用?(将torch_geometric.data.Data 对象的 Python 列表整理为 InMemoryDataset 的内部存储格式。)

一般还是手动加载数据集,因为内存有限数据集可能非常大。

你可能感兴趣的:(PyG官方文档学习,神经网络,深度学习,pytorch)