在GNN领域,大图是非常常见的,但由于GPU显存的限制,大图是无法放到GPU上进行训练的。为此,可以采用邻居采样,这样一来可以将GNN扩展到大图上。在PyG中,邻居采样的方式有很多种,具体详解torch_geometric.loader
。本文以GraphSage中的邻居采样为例进行介绍,其在PyG中实现为NeighborLoader
。
NeighborSampler
也是PyG中关于GraphSage中邻居采样的实现,但已经被弃用,在未来版本中会被删除。
NeighborLoader
详解假设采样的层数为 K K K,每层采样的邻居数为 S k S_k Sk,GraphSage中邻居采样是这样进行的:
下图左是GraphSage中给出的一个2层邻居采样的示例,其中每层采样的邻居数 S k S_k Sk是相等的(图中为3)。
PyG中,GraphSage的邻居采样实现为torch_geometric.loader.NeighborLoader
,其初始化函数参数为:
def __init__(
self,
data: Union[Data, HeteroData],
num_neighbors: NumNeighbors,
input_nodes: InputNodes = None,
replace: bool = False,
directed: bool = True,
transform: Callable = None,
neighbor_sampler: Optional[NeighborSampler] = None,
**kwargs,
)
常用参数说明如下:
data
:要采样的图对象,可以为异构图HeteroData
,也可以为同构图Data
;num_neighbors
:每个节点每次迭代(每层)采样的最大邻居数,List[int]
类型,例如[2,2]
表示采样2层,每层中每个节点最多采样2个邻居;input_nodes
:从原始图中采样得到的子图中需要包含的原始图中节点索引,即2.1节中最初的 B \mathcal{B} B,torch.Tensor()
类型;directed
:如果设置为False
,将包括所有采样节点之间的所有边;**kwargs
:torch.utils.data.DataLoader
的额外参数,例如batch_size
,shuffle
(具体详见该API)。为了可视化的美观性,本小节采用的图数据是PyG中提供的KarateClub
数据集,该数据集描述了一个空手道俱乐部会员的社交关系,节点为34名会员,如果两位会员在俱乐部之外仍保持社交关系,则在对应节点间连边,该数据集的可视化如下所示:
下面是对该数据集的加载、可视化以及邻居采样的源码:
import torch
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.loader import NeighborLoader
def draw(graph):
nids = graph.n_id
graph = to_networkx(graph)
for i, nid in enumerate(nids):
graph.nodes[i]['txt'] = str(nid.item())
node_labels = nx.get_node_attributes(graph, 'txt')
# print(node_labels)
# {0: '14', 1: '32', 2: '33', 3: '18', 4: '30', 5: '28', 6: '20'}
nx.draw_networkx(graph, labels=node_labels, node_color='#00BFFF')
plt.axis("off")
plt.show()
dataset = KarateClub()
g = dataset[0]
# print(g)
# Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])
g.n_id = torch.arange(g.num_nodes)
for s in NeighborLoader(g, num_neighbors=[2, 2], input_nodes=torch.Tensor([14])):
draw(s)
break
在上述源码中,设置的采样层数为2层、每个节点每层采样最多采样2个邻居,采样的初始节点集为{14}
,其对应的采样结果如下所示:
从上图可以看出,在第一次迭代中,采样了节点{14}
的两个1跳邻居{32,33}
,然后在第二次迭代中对{32,33}
分别进行采样得到{2,8]}
和{18,30}
。
需要注意是通过NeighborLoader
返回的子图中,全局节点索引会映射到到与该子图对应的局部索引。因此,若要将当前采样子图中的节点映射会原来图中对应的节点,可以在原始图中创建一个属性来完成两者之间的映射,例如采样实践源码中的:
g.n_id = torch.arange(g.num_nodes)
如此以来,采样后子图中的节点同样包含n_id
属性,这样就可以将子图的节点映射回去了,上述示例中对图进行可视化便利用了这一点,其对应的映射为:
{0: '14', 1: '32', 2: '33', 3: '18', 4: '30', 5: '28', 6: '20'}
PyG中对于邻居采样的实现远远不止上述这一种,具体参见如下官网资料: