PyG edge index 转换回 邻接矩阵

PyG的edge index形式是 [ ( n o d e 1 , n o d e 2 ) , ( n o d e 1 , n o d e 3 ) . . . ] [(node_1,node_2), (node_1, node_3)...] [(node1,node2),(node1,node3)...]这种edge pair。

naive

直接for循环,吧edge index里面的位置填充1:

import torch  
  
def edge_index_to_adjacency_matrix(edge_index, num_nodes):  
    # 创建大小为 (num_nodes, num_nodes) 的二维张量  
    adjacency_matrix = torch.zeros(num_nodes, num_nodes)  
      
    # 根据边索引填充邻接矩阵的元素  
    for i, j in zip(*edge_index):  
        adjacency_matrix[i, j] = 1  
        adjacency_matrix[j, i] = 1  
      
    return adjacency_matrix

效率很低

利用传播机制

用PyTorch的广播机制,通过将边索引直接作为索引,一次性将对应的邻接矩阵元素设置为1,避免了使用for循环进行逐个元素的填充。这种方法在大规模图形上具有更高的效率。

import torch  
  
def edge_index_to_adjacency_matrix(edge_index, num_nodes):  
    # 构建一个大小为 (num_nodes, num_nodes) 的零矩阵  
    adjacency_matrix = torch.zeros(num_nodes, num_nodes, dtype=torch.uint8)  
      
    # 使用索引广播机制,一次性将边索引映射到邻接矩阵的相应位置上  
    adjacency_matrix[edge_index[0], edge_index[1]] = 1  
    adjacency_matrix[edge_index[1], edge_index[0]] = 1  
      
    return adjacency_matrix

你可能感兴趣的:(pytorch,python,scipy,pytorch,邻接矩阵)