Graphormer代码解读-spatial pos

self.spatial_pos_encoder = nn.Embedding(num_spatial, num_heads, padding_idx=0)
# spatial pos
# [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]     
 spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)   
 graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + spatial_pos_bias     

spatial_pos为节点到每个节点的跳数,数据原始维度为(graph,node,node)
nn.Embedding会将向量扩充一维,将跳数从one-hot到向量化

你可能感兴趣的:(机器学习,论文,图论,知识图谱,机器学习)