从代码角度深入浅出图神经网络系列笔记(三)

文章目录

  • 前言
  • 构建数据集
    • 1、Dataset
    • 2、官方文档例子
    • 3、process解读
  • MINI-BATCHING

前言

这一节笔记中主要针对继承Dataset,分次加载到内存,这种数据集一般很大,不适合一次性加载完毕,需要分批加载处理。

构建数据集

1、Dataset

pytorch geometric 构建数据集分两种:
1、继承InMemoryDataset,一次性加载所有的数据到内存
2、继承Dataset,分次加载到内存

Mini-Batching:将一组样本组合成一个统一的表示形式,进行并行处理

2、官方文档例子

从代码角度深入浅出图神经网络系列笔记(三)_第1张图片
首先还是看下引入的库文件,对比一下InMemoryDataset,这里我们引入的是Dataset,对比一下这两个库,初始化的参数完全一致
从代码角度深入浅出图神经网络系列笔记(三)_第2张图片
主要是Dataset多了len()get()

  • torch_geometric.data.Dataset.len(): 返回数据集中的样本数

  • torch_geometric.data.Dataset.get(): 实现加载单个图的逻辑

下面来看对比分析:

import os.path as osp # 调用系统路径

import torch
from torch_geometric.data import Dataset


class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)# 对比InMemoryDataset 这块少了直接加载的代码,因为我们需要分次加载

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self): # 对比一次性加载,这里会有多个
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.

    def process(self):
        i = 0
        for raw_path in self.raw_paths:
            # Read data from `raw_path`.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
            i += 1

    def len(self): # 返回数目
        return len(self.processed_file_names)

    def get(self, idx): # 一个一个文件手动加载到内存中
        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
        return data

3、process解读

下面是我从某博客中引用的一段话,个人觉得说的挺好的。process()方法存在的意义是:

  1. 原始的格式可能是 csv 或者 mat,在process()函数里可以转化为 pt 格式的文件。
  2. 这样在get()方法中,就可以直接使用torch.load()函数,读取 pt 格式的文件,返回的是torch_geometric.data.Data类型的数据。
  3. 而不用在get()方法里面做数据转换操作,比如说,把其他格式的数据转换为 torch_geometric.data.Data类型的数据。
  4. 当然我们也可以提前把数据转换为 torch_geometric.data.Data类型,使用 pt 格式保存在self.processed_dir中。

Ps:上面这段话部分针对的是Dataset,其实InMemoryDataset也差不多,只不过最后不需要逐个加载进内存,而是直接加载进内存

MINI-BATCHING

官方文档地址:

https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html#pairs-of-graphs

我觉得这里我没有理解,视频的解释实在是太少了,之后再补

你可能感兴趣的:(从代码角度深入浅出图神经网络,pytorch,图神经网络,代码)