yelp数据集_pyg学习04:数据集创建

引言

在pyg的torch_geometric.datasets的包中,已经包含许多常见的数据集,但是针对的自己的需求去构建或者引用其他的一些数据集的时候,我们需要在pyg提供的函数的基础上进行数据的规范化。

在pyg中,可以构建两种类型的数据集,一种是In Memory Dataset,另一种是Larger Dataset。前者需要引入的包torch_geometric.data.InMemoryDataset,适用于小数据集,直接全部加载至内存;后者需要引入torch_geometric.data.Dataset,适用于分批大数据。需要注意的是前者是继承自后者的。

在创建数据集之前,有几点需要注意

  1. 每个数据集的根目录分成raw_dir和processed_dir,前者是下载的原始文件需要存储的地方;后者是处理后的数据存储的地方。
  2. 每个数据集可以经过transforma pre_transforma pre_filter函数,默认是None,这个在介绍之前例子的时候说过了,为了方便阅读,这里重述一遍。

yelp数据集_pyg学习04:数据集创建_第1张图片

创建"In Memory Datasets"

四个基本函数

  1. torch_geometric.data.InMemoryDataset.raw_file_names(): 返回一个文件列表,包含raw_dir中的文件目录。可以根据此列表来决定哪些需要下载或者已下载的直接跳过。
  2. torch_geometric.data.InMemoryDataset.processed_file_names():
    返回一个处理后的文件列表,包含processed_dir中的文件目录。据此来决定需要跳过。也就说,在你处理完后,你再次运行该程序将不会二次处理。
  3. torch_geometric.data.InMemoryDataset.download():
    将原始数据下载到 raw_dir 文件夹. 4. torch_geometric.data.InMemoryDataset.process():
    处理原始数据将结果存放至 processed_dir 文件夹. 注意,这里需要将结果存储成Data格式。为解决python处理达标存储慢的的问题,通过torch_geometric.data.InMemoryDataset.collate()将许多Data列表整理成一个很大的Data对象,并且返回一个slices索引字典,因此我们需要设置self.dataself.slice这两个属性。

简单数据集搭建

具体代码如下

import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import download_url
import os
from torch_geometric.io import read_planetoid_data
from torch_geometric.datasets import Planetoid

# data=Planetoid(name='Cora',root='data')
class SimpleExample(InMemoryDataset):
    #这里参考InMemoryDataset类,这里transform和filter都没用到
    def __init__(self,url= 'https://github.com/kimiyoung/planetoid/raw/master/data', dataname='cora',root='dataset', transform=None, pre_transform=None,pre_filter=None):

        self.url=url
        self.dataname=dataname
        self.transform=transform
        self.pre_filter=pre_filter
        self.pre_transform=pre_transform

        self.raw=os.path.join(root,dataname,'raw')
        self.processed=os.path.join(root,dataname,'processed')
        super(SimpleExample,self).__init__(root=root,transform=transform,pre_transform=pre_transform,pre_filter=pre_filter)
        #其中processed_paths来自于Dataset类,返回数据
        self.x, self.slices = torch.load(self.processed_paths[0])
    #接下来写好四个函数,其中前两个是属性获取,所以这里采用property修饰器

    #返回原始文件列表
    @property
    def raw_file_names(self):
        names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
        return ['ind.{}.{}'.format(self.dataname.lower(), name) for name in names]
    #返回需要跳过的文件列表
    @property
    def processed_file_names(self):
        return ['data.pt']

    #下载原生文件
    def download(self):
        for name in self.raw_file_names:
            download_url('{}/{}'.format(self.url, name), self.raw)

    def process(self):
        data=read_planetoid_data(self.raw,self.dataname)
        data_list = [data]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)

        torch.save((data, slices), self.processed_paths[0])
    #显示属性
    def __repr__(self):
        return '{}()'.format(self.dataname)

data=SimpleExample()
print(data[0])
print(data.processed_file_names)

具体的执行步骤如下图红框所示,在Dataset类中可以找到

yelp数据集_pyg学习04:数据集创建_第2张图片

由此可见,程序默认先执行下载操作,然后再执行处理操作。

创建大数据集

三点注意

在创建大数据集的时候,需要注意的节点是 1. 继承自Dataset类 2. 自己实现len()方法,该方法返回你数据集的长度 3. 自己实现get()方法,实现载入单个图的逻辑。

简单数据集搭建

因为暂时用不到大的数据集,所以暂时先不处理,要是有兴趣,可以参考下面的代码

https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/yelp.html#Yelp

最后,有兴趣的朋友,可以加入群:777486287,方便大家交流探讨

你可能感兴趣的:(yelp数据集)