PyG的edge index形式是 [ ( n o d e 1 , n o d e 2 ) , ( n o d e 1 , n o d e 3 ) . . . ] [(node_1,node_2), (node_1, node_3)...] [(node1,node2),(node1,node3)...]这种edge pair。
直接for循环,吧edge index里面的位置填充1:
import torch
def edge_index_to_adjacency_matrix(edge_index, num_nodes):
# 创建大小为 (num_nodes, num_nodes) 的二维张量
adjacency_matrix = torch.zeros(num_nodes, num_nodes)
# 根据边索引填充邻接矩阵的元素
for i, j in zip(*edge_index):
adjacency_matrix[i, j] = 1
adjacency_matrix[j, i] = 1
return adjacency_matrix
效率很低
用PyTorch的广播机制,通过将边索引直接作为索引,一次性将对应的邻接矩阵元素设置为1,避免了使用for循环进行逐个元素的填充。这种方法在大规模图形上具有更高的效率。
import torch
def edge_index_to_adjacency_matrix(edge_index, num_nodes):
# 构建一个大小为 (num_nodes, num_nodes) 的零矩阵
adjacency_matrix = torch.zeros(num_nodes, num_nodes, dtype=torch.uint8)
# 使用索引广播机制,一次性将边索引映射到邻接矩阵的相应位置上
adjacency_matrix[edge_index[0], edge_index[1]] = 1
adjacency_matrix[edge_index[1], edge_index[0]] = 1
return adjacency_matrix