Reference: torch_geometric.data, pytorch_geometric documentation (pytorch-geometric.readthedocs.io)
Data object | Description |
---|---|
Data | A data object describing a homogeneous graph. |
HeteroData | A data object describing a heterogeneous graph, holding multiple node and/or edge types in disjunct storage objects. |
Batch | A data object describing a batch of graphs as one big (disconnected) graph. |
TemporalData | A data object composed of a stream of events describing a temporal graph. |
Dataset | Dataset base class for creating graph datasets. |
InMemoryDataset | Dataset base class for creating graph datasets which easily fit into CPU memory (the whole dataset is held in memory, so access is fast). |
Object | Description |
---|---|
FeatureStore | An abstract base class to access features from a remote feature store. |
GraphStore | An abstract base class to access edges from a remote graph store. |
TensorAttr | Defines the attributes of a FeatureStore tensor. |
EdgeAttr | Defines the attributes of a GraphStore edge. |
The following wrappers convert PyG data objects into PyTorch Lightning data modules. These modules support multi-GPU training at three levels: graph, node, and link.
Object | Description |
---|---|
LightningDataset | Converts a set of Dataset objects into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU graph-level training via PyTorch Lightning. |
LightningNodeData | Converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU node-level training via PyTorch Lightning. |
LightningLinkData | Converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant, which can be automatically used as a datamodule for multi-GPU link-level training (such as for link prediction) via PyTorch Lightning. |
Convenient helper functions for downloading and extracting raw files:
Function | Description |
---|---|
makedirs | Recursively creates a directory. |
download_url | Downloads the content of a URL to a specific folder. |
extract_tar | Extracts a tar archive to a specific folder. |
extract_zip | Extracts a zip archive to a specific folder. |
extract_bz2 | Extracts a bz2 archive to a specific folder. |
extract_gz | Extracts a gz archive to a specific folder. |
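Here is a stdlib-only sketch that makes the helper table concrete. It does not call PyG. `makedirs_sketch` and `extract_zip_sketch` are hypothetical stand-ins that mirror the documented behavior of `makedirs` and `extract_zip`. A locally built archive replaces a file fetched with `download_url`, so the example runs offline.

```python
import os
import tempfile
import zipfile


def makedirs_sketch(path: str) -> None:
    # Recursively create `path`; no error if it already exists.
    os.makedirs(os.path.expanduser(os.path.normpath(path)), exist_ok=True)


def extract_zip_sketch(path: str, folder: str) -> None:
    # Extract every member of the zip archive at `path` into `folder`.
    with zipfile.ZipFile(path, "r") as f:
        f.extractall(folder)


root = tempfile.mkdtemp()

# A locally built archive stands in for a file fetched with download_url.
archive = os.path.join(root, "toy.zip")
with zipfile.ZipFile(archive, "w") as f:
    f.writestr("edges.csv", "0,1\n1,2\n")

raw_dir = os.path.join(root, "raw", "nested")
makedirs_sketch(raw_dir)
extract_zip_sketch(archive, raw_dir)

with open(os.path.join(raw_dir, "edges.csv")) as fh:
    print(fh.read().splitlines())  # ['0,1', '1,2']
```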
torch_geometric.data.Data — pytorch_geometric documentation (pytorch-geometric.readthedocs.io)
```python
import torch
from torch_geometric.data import Data

# x: node feature matrix of shape [num_nodes, num_node_features]
# edge_index: graph connectivity in COO format, shape [2, num_edges]
data = Data(x=x, edge_index=edge_index, ...)

# Add additional arguments to `data`:
data.train_idx = torch.tensor([...], dtype=torch.long)
data.test_mask = torch.tensor([...], dtype=torch.bool)

# Analyzing the graph structure:
data.num_nodes
>>> 23

data.is_directed()
>>> False

# PyTorch tensor functionality:
data = data.pin_memory()
data = data.to('cuda:0', non_blocking=True)
```
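The sketch below shows what `num_nodes` and `is_directed()` compute, in plain Python without PyG. It makes two assumptions. First, with no feature matrix the node count is inferred as the largest index plus one. Second, an edge set counts as undirected when every edge (u, v) has its reverse (v, u). PyG's own check may also compare edge attributes, which this sketch ignores. `infer_num_nodes` and `is_directed` are hypothetical helpers.

```python
from typing import List, Tuple

# edge_index in COO format: row 0 holds source nodes, row 1 target nodes.
EdgeIndex = Tuple[List[int], List[int]]


def infer_num_nodes(edge_index: EdgeIndex) -> int:
    # Without a feature matrix, the node count can only be inferred
    # from the largest node index that appears in edge_index.
    src, dst = edge_index
    return max(src + dst) + 1 if src else 0


def is_directed(edge_index: EdgeIndex) -> bool:
    # Undirected means every edge (u, v) also appears as (v, u).
    src, dst = edge_index
    edges = set(zip(src, dst))
    return any((v, u) not in edges for u, v in edges)


undirected = ([0, 1, 1, 2], [1, 0, 2, 1])  # 0-1 and 1-2, both directions
directed = ([0, 1], [1, 2])                # 0->1 and 1->2 only

print(infer_num_nodes(undirected), is_directed(undirected))  # 3 False
print(infer_num_nodes(directed), is_directed(directed))      # 3 True
```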
torch_geometric.data.HeteroData — pytorch_geometric documentation (pytorch-geometric.readthedocs.io)
Note that there are several ways to create a heterogeneous graph.
Method 1:
```python
import torch
from torch_geometric.data import HeteroData

data = HeteroData()

# Create two node types "paper" and "author" holding a feature matrix:
data['paper'].x = torch.randn(num_papers, num_paper_features)
data['author'].x = torch.randn(num_authors, num_authors_features)

# Create an edge type "(author, writes, paper)" and building the
# graph connectivity:
data['author', 'writes', 'paper'].edge_index = ...  # [2, num_edges]

data['paper'].num_nodes
>>> 23

data['author', 'writes', 'paper'].num_edges
>>> 52

# PyTorch tensor functionality:
data = data.pin_memory()
data = data.to('cuda:0', non_blocking=True)
```
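Method 2: there are three equivalent ways to create a node type "paper" that holds a feature matrix `x_paper` named `x`:
```python
from torch_geometric.data import HeteroData

data = HeteroData()
data['paper'].x = x_paper

data = HeteroData(paper={ 'x': x_paper })

data = HeteroData({'paper': { 'x': x_paper }})
```
Method 3: there are three equivalent ways to create an edge type that holds a connectivity matrix `edge_index_author_paper` named `edge_index`. The edges run from source node type "author" to destination node type "paper", with relation "writes":
```python
data = HeteroData()
data['author', 'writes', 'paper'].edge_index = edge_index_author_paper

data = HeteroData(author__writes__paper={
    'edge_index': edge_index_author_paper
})

data = HeteroData({
    ('author', 'writes', 'paper'):
    { 'edge_index': edge_index_author_paper }
})
```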
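The three spellings in Method 3 are interchangeable, which suggests that the double-underscore string is turned into a (source, relation, destination) tuple. The sketch below models that normalization with plain dictionaries. `normalize_key` and `build_store` are hypothetical helpers, not PyG API, and HeteroData's internal handling may differ in detail.

```python
from typing import Any, Dict, Optional, Tuple, Union

# A store key is either a node type ('paper') or an edge type triple.
Key = Union[str, Tuple[str, ...]]


def normalize_key(key: Key) -> Key:
    # 'author__writes__paper' -> ('author', 'writes', 'paper');
    # node types such as 'paper' and tuple keys are returned unchanged.
    if isinstance(key, str) and "__" in key:
        return tuple(key.split("__"))
    return key


def build_store(mapping: Optional[Dict[Key, Dict[str, Any]]] = None,
                **kwargs: Dict[str, Any]) -> Dict[Key, Dict[str, Any]]:
    # Merge the positional-dict and keyword styles into one {key: attrs} dict.
    store: Dict[Key, Dict[str, Any]] = {}
    for key, attrs in list((mapping or {}).items()) + list(kwargs.items()):
        store.setdefault(normalize_key(key), {}).update(attrs)
    return store


edge_index_author_paper = [[0, 1], [0, 0]]  # placeholder: 2 authors -> paper 0

a = build_store(author__writes__paper={"edge_index": edge_index_author_paper})
b = build_store({("author", "writes", "paper"): {"edge_index": edge_index_author_paper}})
print(a == b)          # True
print(list(a.keys()))  # [('author', 'writes', 'paper')]
```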
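The `DataLoader` returns each mini-batch as a `Batch` object.
`Batch` describes a batch of graphs as one big (disconnected) graph. It inherits from torch_geometric.data.Data or torch_geometric.data.HeteroData. Each individual graph can be identified via the assignment vector `batch`, which maps every node to the identifier of the graph it belongs to. A Batch object is constructed from a Python list of Data or HeteroData objects, and the assignment vector `batch` is created on the fly. Assignment vectors are also created for each key in `follow_batch`, and any keys given in `exclude_keys` are excluded.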
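How it works:
Neural networks are usually trained in a batch-wise fashion. PyG parallelizes over a mini-batch in two steps. It builds a sparse block-diagonal adjacency matrix (defined by `edge_index`). It also concatenates the feature and target matrices along the node dimension. This composition allows the examples in one batch to have different numbers of nodes and edges: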
$$
\mathbf{A} = \begin{bmatrix} \mathbf{A}_1 & & \\ & \ddots & \\ & & \mathbf{A}_n \end{bmatrix}, \qquad \mathbf{X} = \begin{bmatrix} \mathbf{X}_1 \\ \vdots \\ \mathbf{X}_n \end{bmatrix}, \qquad \mathbf{Y} = \begin{bmatrix} \mathbf{Y}_1 \\ \vdots \\ \mathbf{Y}_n \end{bmatrix}
$$
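The block-diagonal matrix never has to be materialized. Shifting each graph's node indices by the number of nodes that precede it is enough. The plain-Python sketch below reproduces the stacked X and the `batch` assignment vector described above, without PyG. The `collate` helper and the dict layout are illustrative assumptions.

```python
from typing import Dict, List


def collate(graphs: List[Dict[str, list]]) -> Dict[str, list]:
    """Merge small graphs into one disconnected graph.

    Each graph is a dict with 'x' (list of node feature rows) and
    'edge_index' (two lists: sources, targets).
    """
    x, src, dst, batch = [], [], [], []
    offset = 0  # number of nodes in all previous graphs
    for graph_id, g in enumerate(graphs):
        num_nodes = len(g["x"])
        x.extend(g["x"])  # stack X_i vertically
        # Shift node indices so that A_i lands on the diagonal.
        src.extend(s + offset for s in g["edge_index"][0])
        dst.extend(d + offset for d in g["edge_index"][1])
        batch.extend([graph_id] * num_nodes)  # node -> graph assignment
        offset += num_nodes
    return {"x": x, "edge_index": [src, dst], "batch": batch}


g1 = {"x": [[1.0], [2.0]], "edge_index": [[0, 1], [1, 0]]}
g2 = {"x": [[3.0], [4.0], [5.0]], "edge_index": [[0, 2], [2, 1]]}
out = collate([g1, g2])
print(out["edge_index"])  # [[0, 1, 2, 4], [1, 0, 4, 3]]
print(out["batch"])       # [0, 0, 1, 1, 1]
```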
The class method `Batch.from_data_list` constructs a Batch object from a Python list of Data or HeteroData objects:
```python
Batch.from_data_list(data_list: List[BaseData], follow_batch: Optional[List[str]] = None, exclude_keys: Optional[List[str]] = None)
```
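Parameters of the Dataset and InMemoryDataset constructors:
Parameter | Description |
---|---|
root (str, optional) | Root directory where the dataset should be saved. (default: None) |
transform (callable, optional) | A function/transform that takes in a Data object and returns a transformed version. The data object will be transformed before every access. (default: None) |
pre_transform (callable, optional) | A function/transform that takes in a Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None) |
pre_filter (callable, optional) | A function that takes in a Data object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: None) |
log (bool, optional) | Whether to print any console output while downloading and processing the dataset. (default: True) |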
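The difference between the three callables is when they run. The toy class below is a plain-Python sketch, not PyG's Dataset. Integers stand in for Data objects, and nothing is written to disk. It shows that `pre_filter` and `pre_transform` run once at processing time, while `transform` runs on every access. `ToyDataset` and `augment` are hypothetical names.

```python
from typing import Callable, List, Optional


class ToyDataset:
    """Illustrates when pre_filter, pre_transform and transform run."""

    def __init__(self, raw: List[int],
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        self.transform = transform
        # "process" step: runs once; in PyG the result is cached on disk.
        items = [r for r in raw if pre_filter is None or pre_filter(r)]
        if pre_transform is not None:
            items = [pre_transform(r) for r in items]
        self.processed = items

    def __len__(self) -> int:
        return len(self.processed)

    def __getitem__(self, idx: int):
        item = self.processed[idx]
        # transform runs on every access and its result is never stored.
        return item if self.transform is None else self.transform(item)


calls = []


def augment(r: int) -> int:
    calls.append(r)  # record each time transform runs
    return r + 1


ds = ToyDataset(raw=[1, 2, 3, 4],
                pre_filter=lambda r: r % 2 == 0,  # keep even items only
                pre_transform=lambda r: r * 10,   # applied once
                transform=augment)                # applied per access
print(ds.processed)  # [20, 40]
print(ds[0], ds[0])  # 21 21
print(len(calls))    # 2
```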
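Basic template:
The template below applies `pre_filter` and `pre_transform` only once, inside `process()`, and saves the result to disk. This saves resources, because the work is not repeated every time the dataset is loaded. The base class also implements `__getitem__`, so a single graph can be accessed as `dataset[index]`.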
```python
import torch
from torch_geometric.data import InMemoryDataset, download_url


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        # Load the collated data and slice offsets written by process().
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Download to `self.raw_dir`.
        download_url(url, self.raw_dir)
        ...

    def process(self):
        # Read data into huge `Data` list.
        data_list = [...]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
```
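`self.collate()` stores the whole dataset as one big concatenated object plus a `slices` dictionary of boundary offsets. That is why a single `data.pt` file is enough. The plain-Python sketch below handles one attribute only. It shows how the boundaries are recorded and how one graph can be cut back out for `dataset[idx]`. `collate_with_slices` and `get` are hypothetical helpers, not PyG functions.

```python
from typing import Dict, List, Tuple


def collate_with_slices(items: List[List[float]]) -> Tuple[List[float], Dict[str, List[int]]]:
    # Concatenate per-graph attribute lists into one flat list and record
    # where each graph starts and ends (like `slices` in InMemoryDataset).
    flat: List[float] = []
    bounds = [0]
    for item in items:
        flat.extend(item)
        bounds.append(bounds[-1] + len(item))
    return flat, {"x": bounds}


def get(flat: List[float], slices: Dict[str, List[int]], idx: int) -> List[float]:
    # Recover graph `idx` by cutting the flat list at its recorded bounds.
    start, end = slices["x"][idx], slices["x"][idx + 1]
    return flat[start:end]


x_per_graph = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]  # node features of 3 graphs
data, slices = collate_with_slices(x_per_graph)
print(slices)                # {'x': [0, 2, 3, 6]}
print(get(data, slices, 2))  # [4.0, 5.0, 6.0]
```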
Large datasets may not fit into memory. In that case, subclass Dataset instead and store each processed graph as a separate file on disk:
```python
import os.path as osp

import torch
from torch_geometric.data import Data, Dataset, download_url


class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        # Must match the names written in process(), which start at data_0.pt.
        return ['data_0.pt', 'data_1.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.
        path = download_url(url, self.raw_dir)
        ...

    def process(self):
        idx = 0
        for raw_path in self.raw_paths:
            # Read data from `raw_path`.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            # One file per graph.
            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
            idx += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        # Load a single graph from disk on demand.
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data
```
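The point of this design is that only the requested graph is read into memory. The sketch below mimics the one-file-per-sample layout of `process()`/`get()`. It uses `pickle` instead of `torch.save`, so it runs without PyTorch. `OnDiskStore` is a hypothetical class, not part of PyG.

```python
import os
import pickle
import tempfile
from typing import Any, List


class OnDiskStore:
    """One file per sample, loaded lazily in get(idx)."""

    def __init__(self, processed_dir: str, samples: List[Any]):
        self.processed_dir = processed_dir
        os.makedirs(processed_dir, exist_ok=True)
        for idx, sample in enumerate(samples):
            # Mirrors torch.save(data, ... f'data_{idx}.pt') in process().
            with open(self._path(idx), "wb") as f:
                pickle.dump(sample, f)
        self.num_samples = len(samples)

    def _path(self, idx: int) -> str:
        return os.path.join(self.processed_dir, f"data_{idx}.pkl")

    def len(self) -> int:
        return self.num_samples

    def get(self, idx: int) -> Any:
        # Only the requested sample is read from disk.
        with open(self._path(idx), "rb") as f:
            return pickle.load(f)


store = OnDiskStore(tempfile.mkdtemp(), [{"x": [1, 2]}, {"x": [3]}])
print(store.len(), store.get(1))  # 2 {'x': [3]}
```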
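TemporalData is a data object composed of a stream of events describing a temporal graph. A TemporalData object can hold a list of events with structured messages; the events can be understood as temporal edges in a graph. An event is composed of a source node, a destination node, a timestamp and a message. Any continuous-time dynamic graph (CTDG) can be represented with these four values.
In general, TemporalData tries to mimic the behavior of a regular Python dictionary. It also provides useful functionality for analyzing graph structures, along with basic PyTorch tensor functionality.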
```python
class TemporalData(src: Optional[Tensor] = None, dst: Optional[Tensor] = None, t: Optional[Tensor] = None, msg: Optional[Tensor] = None, **kwargs)
```
```python
import torch
from torch_geometric.data import TemporalData

# Node IDs should be integer (long) indices starting at 0.
src = torch.tensor([0, 1, 2, 3])
dst = torch.tensor([1, 2, 3, 4])
t = torch.tensor([1000, 1010, 1100, 2000])
msg = torch.tensor([[1.], [1.], [0.], [0.]])  # one message vector per event

events = TemporalData(src=src, dst=dst, t=t, msg=msg)

# Add additional arguments to `events`:
events.y = torch.tensor([1, 1, 0, 0])

# It is also possible to set additional arguments in the constructor
events = TemporalData(src=src, dst=dst, t=t, msg=msg,
                      y=torch.tensor([1, 1, 0, 0]))

# Get the number of events:
events.num_events
>>> 4

# Analyzing the graph structure (nodes 0..4):
events.num_nodes
>>> 5

# PyTorch tensor functionality:
events = events.pin_memory()
events = events.to('cuda:0', non_blocking=True)
```
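Because events carry timestamps, a CTDG is normally split by time rather than at random, so that no future event leaks into training. The plain-Python sketch below sorts a small event stream by `t` and cuts it into train/val/test parts. `Event` and `chrono_split` are hypothetical names, and the ratios are arbitrary. Recent PyG versions offer a time-based `train_val_test_split()` on TemporalData that follows the same idea.

```python
from typing import List, NamedTuple, Tuple


class Event(NamedTuple):
    src: int
    dst: int
    t: int
    msg: float


def chrono_split(events: List[Event], val_ratio: float = 0.25,
                 test_ratio: float = 0.25) -> Tuple[List[Event], List[Event], List[Event]]:
    # Sort by timestamp so that no future event ends up in the training part.
    events = sorted(events, key=lambda e: e.t)
    n = len(events)
    val_start = int(n * (1 - val_ratio - test_ratio))
    test_start = int(n * (1 - test_ratio))
    return events[:val_start], events[val_start:test_start], events[test_start:]


stream = [Event(3, 4, 2000, 0.0), Event(0, 1, 1000, 1.0),
          Event(2, 3, 1100, 0.0), Event(1, 2, 1010, 1.0)]
train, val, test = chrono_split(stream)
print([e.t for e in train], [e.t for e in val], [e.t for e in test])
# [1000, 1010] [1100] [2000]
```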
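LightningDataset converts a set of Dataset objects into a pytorch_lightning.LightningDataModule variant. This variant can be automatically used as a datamodule for multi-GPU graph-level training via PyTorch Lightning. LightningDataset takes care of providing mini-batches via DataLoader.
Currently, only two PyTorch Lightning training strategies are supported, pytorch_lightning.strategies.SingleDeviceStrategy and pytorch_lightning.strategies.DDPSpawnStrategy. This restriction ensures that data is shared correctly across all devices/processes: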
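```python
import pytorch_lightning as pl

# `model` is a LightningModule and `datamodule` a LightningDataset (defined elsewhere).
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                     devices=4)
trainer.fit(model, datamodule)
```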
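Under a DDP-style strategy, each process sees only its own shard of the dataset. By default, Lightning arranges this by wrapping the loader with a distributed sampler. The plain-Python sketch below mimics the index partitioning of `torch.utils.data.DistributedSampler` with `shuffle=False`. `shard_indices` is a hypothetical helper, and the 4 replicas correspond to `devices=4` in the Trainer call above.

```python
from typing import List


def shard_indices(num_samples: int, num_replicas: int, rank: int) -> List[int]:
    # Pad the index list by wrapping around so every process gets the same
    # number of samples, then give each rank every `num_replicas`-th index.
    indices = list(range(num_samples))
    per_replica = -(-num_samples // num_replicas)  # ceiling division
    total = per_replica * num_replicas
    indices += indices[: total - num_samples]
    return indices[rank:total:num_replicas]


for rank in range(4):
    print(rank, shard_indices(10, 4, rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```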
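LightningNodeData converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant. This variant can be automatically used as a datamodule for multi-GPU node-level training via PyTorch Lightning. LightningNodeData takes care of providing mini-batches via NeighborLoader.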
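```python
import pytorch_lightning as pl

# `model` is a LightningModule and `datamodule` a LightningNodeData (defined elsewhere).
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                     devices=4)
trainer.fit(model, datamodule)
```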
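NeighborLoader keeps node-level mini-batches small. Instead of using the full graph, it samples a fixed number of neighbors per hop around the seed nodes. The plain-Python sketch below illustrates that idea on an adjacency dict. `sample_neighbors` is a hypothetical helper and not PyG's implementation, which also returns the sampled edges and relabels node indices. The sample is random, so no fixed output is shown.

```python
import random
from typing import Dict, List, Set


def sample_neighbors(adj: Dict[int, List[int]], seeds: List[int],
                     num_neighbors: List[int], rng: random.Random) -> Set[int]:
    # For each hop, keep at most num_neighbors[hop] random neighbors of every
    # node in the current frontier; the union forms the sampled subgraph.
    sampled = set(seeds)
    frontier = list(seeds)
    for k in num_neighbors:
        next_frontier = []
        for node in frontier:
            nbrs = adj.get(node, [])
            for n in rng.sample(nbrs, min(k, len(nbrs))):
                if n not in sampled:
                    sampled.add(n)
                    next_frontier.append(n)
        frontier = next_frontier
    return sampled


adj = {0: [1, 2, 3, 4], 1: [0, 5], 2: [0, 6], 3: [0], 4: [0], 5: [1], 6: [2]}
# Two hops: up to 2 neighbors of the seed, then up to 1 neighbor of each of those.
sub = sample_neighbors(adj, seeds=[0], num_neighbors=[2, 1], rng=random.Random(0))
print(sorted(sub))
```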
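LightningLinkData converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant. This variant can be automatically used as a datamodule for multi-GPU link-level training (such as link prediction) via PyTorch Lightning. LightningLinkData takes care of providing mini-batches via LinkNeighborLoader.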
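```python
import pytorch_lightning as pl

# `model` is a LightningModule and `datamodule` a LightningLinkData (defined elsewhere).
trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu",
                     devices=4)
trainer.fit(model, datamodule)
```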
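For link-level training the seeds are edges rather than nodes. Training usually needs negative (non-existent) edges alongside the positive ones, and LinkNeighborLoader can optionally add such negatives. The plain-Python sketch below shows only the negative-sampling step that produces edge labels. `sample_negative_edges` is a hypothetical helper. The neighborhood sampling around both endpoints is omitted.

```python
import random
from typing import List, Set, Tuple

Edge = Tuple[int, int]


def sample_negative_edges(pos_edges: List[Edge], num_nodes: int,
                          num_samples: int, rng: random.Random) -> List[Edge]:
    # Draw random node pairs that are not existing (positive) edges;
    # they serve as label-0 examples for link prediction.
    existing: Set[Edge] = set(pos_edges)
    available = num_nodes * (num_nodes - 1) - len(existing)
    if num_samples > available:
        raise ValueError("not enough non-edges to sample from")
    negatives: List[Edge] = []
    while len(negatives) < num_samples:
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        if u != v and (u, v) not in existing:
            negatives.append((u, v))
            existing.add((u, v))  # avoid duplicate negatives
    return negatives


pos = [(0, 1), (1, 2), (2, 3)]
neg = sample_negative_edges(pos, num_nodes=5, num_samples=3, rng=random.Random(42))
edge_label_index = pos + neg
edge_label = [1] * len(pos) + [0] * len(neg)
print(edge_label)  # [1, 1, 1, 0, 0, 0]
```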