将数据包装成一个图数据结构(torch_geometric)

import torch
from torch_geometric.data import Data

x = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.float)  # 节点特征矩阵(三个节点,每个节点两个特征)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)  # 边索引矩阵(四条边,每条边包含两个节点索引)
y = torch.tensor([0, 1, 0], dtype=torch.long)  # 每个节点的目标标签

train_mask = torch.tensor([True, False, True])  # 训练掩膜(三个节点)
test_mask = torch.tensor([False, True, False])  # 测试掩膜(三个节点)

data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, test_mask=test_mask)
print(data)

你可能感兴趣的:(记录小知识,pytorch,深度学习,人工智能)