Official documentation link
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
The ENZYMES dataset
Fetching one batch
batch = next(iter(loader))
print(batch)
# Batch(batch=[169], edge_index=[2, 556], ptr=[5], x=[169, 21], y=[4])
Because batch_size=4, the batch contains 4 graphs. The attributes of batch are:
batch.keys
# ['x', 'edge_index', 'y', 'batch', 'ptr']
batch[0].keys
# ['x', 'edge_index', 'y']
for i in range(batch.num_graphs):
    print(batch[i])
"""
Data(edge_index=[2, 178], x=[50, 21], y=[1])
Data(edge_index=[2, 114], x=[30, 21], y=[1])
Data(edge_index=[2, 160], x=[60, 21], y=[1])
Data(edge_index=[2, 104], x=[29, 21], y=[1])
"""
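The ptr attribute
Note the ptr attribute: it records where each graph's nodes start and end, so it is what you need to pull the 4 graphs back out of batch. Here ptr = [0, 50, 80, 140, 169]: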
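batch[0] covers nodes [0:50], 50 - 0 = 50 nodes
batch[1] covers nodes [50:80], 80 - 50 = 30 nodes
batch[2] covers nodes [80:140], 140 - 80 = 60 nodes
batch[3] covers nodes [140:169], 169 - 140 = 29 nodes
The batch attribute
Printing the batch attribute shows 50 consecutive 0s, followed by 30 1s, 60 2s, and 29 3s: every node is labelled with the index of the graph it belongs to.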
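The relationship between ptr and the batch vector can be sketched without PyG at all. The snippet below is a minimal illustration in plain Python lists (not tensors), assuming only the four node counts printed earlier; the names num_nodes, batch_vector, and node_ranges are made up for this example and are not PyG attributes.

```python
from itertools import accumulate

# Node counts of the four graphs in the example batch (taken from the output above).
num_nodes = [50, 30, 60, 29]

# ptr holds cumulative node offsets: graph i owns rows ptr[i]:ptr[i + 1] of x.
ptr = [0] + list(accumulate(num_nodes))

# The batch vector labels every node with the index of its graph.
batch_vector = [g for g, n in enumerate(num_nodes) for _ in range(n)]

# Recover the node range of each graph from ptr.
node_ranges = [(ptr[i], ptr[i + 1]) for i in range(len(num_nodes))]

print(ptr)                # [0, 50, 80, 140, 169]
print(node_ranges)        # [(0, 50), (50, 80), (80, 140), (140, 169)]
print(len(batch_vector))  # 169
```

Both structures carry the same information: ptr is the compact form (one offset per graph boundary), while the batch vector is the expanded form (one label per node).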
How batch tracks which data belongs to which graph
batch.__slices__
"""
{'y': [0, 1, 2, 3, 4],
'x': [0, 50, 80, 140, 169],
'edge_index': [0, 178, 292, 452, 556]}
"""
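When batch[0] is retrieved, each attribute is cut out according to batch.__slices__:
batch[0]['y'] = batch['y'][ batch.__slices__['y'][0] : batch.__slices__['y'][0+1] ]
batch[0]['x'] = batch['x'][ batch.__slices__['x'][0] : batch.__slices__['x'][0+1] ]
batch[0]['edge_index'] = batch['edge_index'][ :, batch.__slices__['edge_index'][0] : batch.__slices__['edge_index'][0+1] ]
Note that edge_index has shape [2, num_edges], so it is sliced along its second dimension. For the later graphs the node indices also have to be shifted back by the graph's starting node offset (50 for batch[1], and so on), because batching renumbers the nodes.
To retrieve batch[1], batch[2], …, batch[n], simply replace 0 with the corresponding index.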
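The slicing rule can be tried out on a toy example. This is a rough sketch in plain Python lists rather than PyG tensors: the two-graph data, the slices dict, and the get_example helper below are all hypothetical stand-ins written for illustration, not the library's implementation. It also shows the one step the formulas gloss over, namely that edge_index is sliced by column and its node indices are shifted back to local numbering.

```python
# Toy batch of two graphs (hypothetical data, plain lists instead of tensors).
# Graph 0 has 3 nodes and 4 edges, graph 1 has 2 nodes and 2 edges.
x = [[0.0], [0.1], [0.2], [1.0], [1.1]]   # one feature per node
edge_index = [[0, 1, 1, 2, 3, 4],         # source nodes (batched numbering)
              [1, 0, 2, 1, 4, 3]]         # target nodes (batched numbering)
y = [7, 9]                                # one label per graph
slices = {"x": [0, 3, 5], "edge_index": [0, 4, 6], "y": [0, 1, 2]}


def get_example(idx):
    """Rebuild graph `idx` from the batched lists (illustrative helper)."""
    xs, xe = slices["x"][idx], slices["x"][idx + 1]
    es, ee = slices["edge_index"][idx], slices["edge_index"][idx + 1]
    ys, ye = slices["y"][idx], slices["y"][idx + 1]
    node_offset = xs  # nodes of earlier graphs come first in x
    return {
        "x": x[xs:xe],
        # Slice columns (edges), then shift indices back to local numbering.
        "edge_index": [[n - node_offset for n in row[es:ee]] for row in edge_index],
        "y": y[ys:ye],
    }


print(get_example(1))
# {'x': [[1.0], [1.1]], 'edge_index': [[0, 1], [1, 0]], 'y': [9]}
```

Changing the argument of get_example plays the same role as replacing 0 with another index in the formulas above.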
https://blog.csdn.net/qq_41795143/article/details/114281387