pytorch Geometric Data使用邻接表去表示图,同时也表示了node特征x, 边属性edge_attr等, 需要注意的是, Data只表示一张图(single graph)
Data(x=None, edge_index=None, edge_attr=None, y=None)
x: 表示节点特征,可选,shape: [num_nodes, num_node_features] 有的图只有结构没有节点特征
edge_index: 表示边,也就是邻接表, shape: [2, num_edges]
注意,因为能表示有向图, 对于无向图,一条边要存入两次,也就是位于节点1和节点2的边,需要写成[[1,2][2,1]]而不能只写入[[1],[2]]; node的编号和edge要对应,也就是 max_num_edges = num_nodesnum_nodes 而不是num_nodesnum_nodes /2
edge_attr: 表示边属性(e.g. , 权重,类型),shape: [num_edges, num_edge_features]
y: 是label,官方文档中说 Graph or node targets with arbitrary shape,所以shape可以是[num_nodes, nodes_label_dimension],或者是[graph_label_dimesnion]
pytorch geometric 构建数据集分两种
import torch
from torch_geometric.data import InMemoryDataset
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data.pt']
def download(self):
# Download to `self.raw_dir`.
def process(self):
# Read data into huge `Data` list.
data_list = [...]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
注意
1.如果需要在initial里面初始化一些参数,如定义mask,需要在super前继承参数
self.num_train_per_class 要放到super(NodeDatasetInMem, self)这一行前面
2.我们主要需要编辑def processed_file_names(self) 和 def process(self),
processed_file_names只需要申明把处理好的dataset存在哪里(路径加文件名)
process就是写一个函数,处理数据成torch_geometric.data.Data的形式,如果是图分类,还需要把多个图存成一个list
要注意x一般是float tensor, y 是 long tensor, mask 是boolean tensor, edge_index是long tensor
而且当y是graph label时, 不能是0-dimension tensor, 也就是说
y = torch.tensor(0, dtype=torch.long)#错
y = torch.tensor([0], dtype=torch.long)#对
3.其余函数作用
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
这个是官方代码里面的,作用就是通过self.collate把数据划分成不同slices去保存读取 (大数据块切成小块)
所以即使只有一个graph写成了data, 在调用self.collate时,也要写成list:
data, slices = self.collate([data])
torch_geometric.data.Dataset.len():
因为Dataset相对于InMemoryDataset,不会一次加载所有函数,而是分批,所有会把数据保存成好几个小数据包(.pt 文件),len() 就是说明有几个数据包,官方的指导:
def len(self):
return len(self.processed_file_names)
可以完全照搬,只需要改变processed_file_names的返回值,例如
还有一个get() 函数
torch_geometric.data.Dataset.get():
这个函数需要返回值时一个data,single graph: Implements the logic to load a single graph
def get(self, idx):
data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
return data
注意, 这里的load里面的函数名要和processed_file_name()返回的函数名一致, idx就是数据包的遍历下标
几个容易出问题的地方
参考:
[1]: https://www.jianshu.com/p/6b9dccbceae4reference