PyG|邻接矩阵转为coo_matrix格式

在使用PyG框架的时候,PyG要求输入的是 edge_index 格式,而不是我们所使用的邻接矩阵格式,即N x N。

import scipy.sparse as sp
import numpy as np
import torch

# adj_matrix 是邻接矩阵
tmp_coo = sp.coo_matrix(adj_matrix)
values = tmp_coo.data
indices = np.vstack((tmp_coo.row,tmp_coo.col))
i = torch.LongTensor(indices)
v = torch.LongTensor(values)
edge_index=torch.sparse_coo_tensor(i,v,tmp_coo.shape)

你可能感兴趣的:(图神经网络,深度学习,复杂网络,pytorch,深度学习,机器学习)