PyG官方github
https://github.com/rusty1s/pytorch_geometric
torch_geometric.data
这个模块包含了一个叫Data的类,可以创建Data对象
创建只需要:
节点的属性/特征(the attributes/features associated with each node, node features)
邻接/边连接信息(the connectivity/adjacency of each node, edge index)
假设现在有一个图包含三个节点,每个节点用一个特征向量表示,特征向量分别为f1、f2、f3
x = torch.tensor([f1,f2,f3], dtype=torch.float) #节点的特征向量构成的特征矩阵
y = torch.tensor([0,1,0], dtype=torch.float)#每个节点归属的类别,这里三个节点分别归属于0,1,0类
边集可以被表示为:
边集以COCO格式存储
边集矩阵大小为2*E,E的大小就是有向边的总条数
矩阵的第一行是源节点的标号,第二行是目标节点的标号
edge_index = torch.tensor([[0,1,2,0,3],
[1,0,1,3,2]],dtype=torch.long)
此处存储的边的顺序并不重要
边上的权值(可选参数,非必要):
edge_attr (Tensor, optional): Edge weights or multi-dimensional
edge features. (default: :obj:`None`)
创建Data对象的完整示例:
import torch
from torch_geometric.data import Data
x = torch.tensor([[2,1],[5,6],[3,7],[12,0]],dtype=torch.float)
y = torch.tensor([[0,2,1,0,3],[3,1,0,1,2]],dtype=torch.long)
edge_index = torch.tensor([[0,1,2,0,3],
[1,0,1,3,2]],dtype=torch,long)
data = Data(x=x,y=y,edge_index=edge_index)
有了data对象就可以快速开始了,PyG官方提供了许多图神经网络算法的接口
例如
可根据需要快速开始,示例
from torch_geometric.nn import GCNConv
in_channels=10
out_channels=5
#in_channels (int) – Size of each input sample.
#out_channels (int) – Size of each output sample.
conv1 = GCNConv(_channels, out_channels, cached=True)
x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
out=conv1(x, edge_index, edge_weight)
out即为使用GCNConv卷积之后的结果
官方关于创建Dataset的指南
https://github.com/rusty1s/pytorch_geometric/blob/a01dc15d5a879e0054f81f611a0dfb2a68ee9424/docs/source/notes/create_dataset.rst
PyG提供两种不同的数据集类:
1·InMemoryDataset
2·Dataset
可以理解为第一种数据集较小,在内存中可存下。第二种数据集较大,首先介绍第一种也就是InMemoryDataset
Raw_file_names()
它返回一个包含没有处理的数据的名字的list
Processed_file_names()
返回一个包含所有处理过的数据的list。在调用process()这个函数后,通常返回的list只有一个元素,它只保存已经处理过的数据的名字。
Download()
这个函数下载数据到你正在工作的目录中,你可以在self.raw_dir中指定。如果你不需要下载数据,你可以在这函数中写一个pass
Process()
这是Dataset中最重要的函数。你需要整合你的数据成一个包含data的list。然后调用 self.collate()去计算将用DataLodadr的片段
官方示例
import torch
from torch_geometric.data import InMemoryDataset
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data.pt']
def download(self):
# Download to `self.raw_dir`.
def process(self):
# Read data into huge `Data` list.
data_list = [...]
if self.pre_filter is not None:
data_list [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
持续更新PyG相关内容,欢迎关注、留言