PyG教程(3):邻居采样

一.为什么需要邻居采样?

在GNN领域,大图是非常常见的,但由于GPU显存的限制,大图是无法放到GPU上进行训练的。为此,可以采用邻居采样,这样一来可以将GNN扩展到大图上。在PyG中,邻居采样的方式有很多种,具体详解torch_geometric.loader。本文以GraphSage中的邻居采样为例进行介绍,其在PyG中实现为NeighborLoader

NeighborSampler也是PyG中关于GraphSage中邻居采样的实现,但已经被弃用,在未来版本中会被删除。

二.NeighborLoader详解

2.1 GraphSage邻居采样原理

假设采样的层数为 K K K,每层采样的邻居数为 S k S_k Sk,GraphSage中邻居采样是这样进行的:

  • 步骤一:首先给定要采样邻居的小批量节点集 B \mathcal{B} B
  • 步骤二:对 B \mathcal{B} B 1 1 1跳(hop)邻居进行采样,然后得到 B 1 \mathcal{B}_1 B1,然后对 B 1 \mathcal{B}_1 B1 1 1 1跳邻居进行采样(即最初结点集的 2 2 2跳邻居)得到 B 2 \mathcal{B}_2 B2,如此往复进行 K K K次,得到最初小批量节点集相关的一个子图。

下图左是GraphSage中给出的一个2层邻居采样的示例,其中每层采样的邻居数 S k S_k Sk是相等的(图中为3)。

PyG教程(3):邻居采样_第1张图片

2.2 API介绍

PyG中,GraphSage的邻居采样实现为torch_geometric.loader.NeighborLoader,其初始化函数参数为:

def __init__(
    self,
    data: Union[Data, HeteroData],
    num_neighbors: NumNeighbors,
    input_nodes: InputNodes = None,
    replace: bool = False,
    directed: bool = True,
    transform: Callable = None,
    neighbor_sampler: Optional[NeighborSampler] = None,
    **kwargs,
)

常用参数说明如下:

  • data:要采样的图对象,可以为异构图HeteroData,也可以为同构图Data
  • num_neighbors:每个节点每次迭代(每层)采样的最大邻居数,List[int]类型,例如[2,2]表示采样2层,每层中每个节点最多采样2个邻居;
  • input_nodes:从原始图中采样得到的子图中需要包含的原始图中节点索引,即2.1节中最初的 B \mathcal{B} Btorch.Tensor()类型;
  • directed:如果设置为False,将包括所有采样节点之间的所有边;
  • **kwargstorch.utils.data.DataLoader的额外参数,例如batch_sizeshuffle(具体详见该API)。

2.3 采样实践

为了可视化的美观性,本小节采用的图数据是PyG中提供的KarateClub数据集,该数据集描述了一个空手道俱乐部会员的社交关系,节点为34名会员,如果两位会员在俱乐部之外仍保持社交关系,则在对应节点间连边,该数据集的可视化如下所示:

PyG教程(3):邻居采样_第2张图片

下面是对该数据集的加载、可视化以及邻居采样的源码:

import torch
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.loader import NeighborLoader


def draw(graph):
    nids = graph.n_id
    graph = to_networkx(graph)
    for i, nid in enumerate(nids):
        graph.nodes[i]['txt'] = str(nid.item())
    node_labels = nx.get_node_attributes(graph, 'txt')
    # print(node_labels)
    # {0: '14', 1: '32', 2: '33', 3: '18', 4: '30', 5: '28', 6: '20'}
    nx.draw_networkx(graph, labels=node_labels, node_color='#00BFFF')
    plt.axis("off")
    plt.show()


dataset = KarateClub()
g = dataset[0]
# print(g)
# Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])
g.n_id = torch.arange(g.num_nodes)

for s in NeighborLoader(g, num_neighbors=[2, 2], input_nodes=torch.Tensor([14])):
    draw(s)
    break

在上述源码中,设置的采样层数为2层、每个节点每层采样最多采样2个邻居,采样的初始节点集为{14},其对应的采样结果如下所示:

PyG教程(3):邻居采样_第3张图片

从上图可以看出,在第一次迭代中,采样了节点{14}的两个1跳邻居{32,33},然后在第二次迭代中对{32,33}分别进行采样得到{2,8]}{18,30}

需要注意是通过NeighborLoader返回的子图中,全局节点索引会映射到到与该子图对应的局部索引。因此,若要将当前采样子图中的节点映射会原来图中对应的节点,可以在原始图中创建一个属性来完成两者之间的映射,例如采样实践源码中的:

g.n_id = torch.arange(g.num_nodes)

如此以来,采样后子图中的节点同样包含n_id属性,这样就可以将子图的节点映射回去了,上述示例中对图进行可视化便利用了这一点,其对应的映射为:

{0: '14', 1: '32', 2: '33', 3: '18', 4: '30', 5: '28', 6: '20'}

结语

PyG中对于邻居采样的实现远远不止上述这一种,具体参见如下官网资料:

  • torch_geometric.loader

你可能感兴趣的:(图神经网络框架,深度学习,pytorch,人工智能,GNN)