pyg的NeighborLoader和LinkNeighborLoader

NeighborLoader

1 数据格式要求

需要传入加载的属性值:

class NeighborLoader(data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], 
num_neighbors: Union[List[int], Dict[Tuple[str, str, str], List[int]]], 
input_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, 
input_time: Optional[Tensor] = None, 
replace: bool = False, 
directed: bool = True, 
disjoint: bool = False, 
temporal_strategy: str = 'uniform', 
time_attr: Optional[str] = None, 
transform: Optional[Callable] = None, 
transform_sampler_output: Optional[Callable] = None, 
is_sorted: bool = False, 
filter_per_worker: bool = False, 
neighbor_sampler: Optional[NeighborSampler] = None, **kwargs)

        data: 要求加载 torch_geometric.data.Data 或者 torch_geometric.data.HeteroData 类型数据;

        num_neighbors: 每轮迭代要采样邻居节点的个数,即第i-1轮要为每个节点采样num_neighbors[i]个节点,如果为-1,则代表所有邻居节点都将被包含(一阶相邻邻居),在异构图中,还可以使用字典来表示每个单独的边缘类型要采样的邻居数量;

        input_nodes : 中心节点集合,用来指导采样一个mini-batch内的节点,如果为None,则代表包含data中的所有节点。如果设置为 None,将考虑所有节点。在异构图中,需要作为包含节点类型和节点索引的元组传递。 (默认值:None)

        input_time (torch.Tensor, optional) – 可选值,用于覆盖 input_nodes 中给定的输入节点的时间戳。如果未设置,将使用 time_attr 中的时间戳作为默认值(如果存在)。需要设置 time_attr 才能使其工作。 (默认值:None)

        replace (bool, optional) – 如果设置为 True,将进行替换采样。 (默认值:False)

        directed (bool, optional) – 如果设置为 False,将包括所有采样节点之间的所有边。 (默认值:True)

        disjoint (bool, optional) – 如果设置为 :obj: True,每个种子节点将创建自己的不相交子图。如果设置为 True,小批量输出将有一个批量向量保存节点到它们各自子图的映射。在时间采样的情况下将自动设置为 True。 (默认值:False) 

        temporal_strategy (str, optional) -- 使用时间采样时的采样策略(“uniform”、“last”)。如果设置为“uniform”,将在满足时间约束的邻居之间统一采样。如果设置为“last”,将对满足时间约束的最后 num_neighbors 进行采样。 (默认值:“uniform”)

         transform (callable, optional) – 一个函数/转换,它接受一个采样的小批量并返回一个转换后的版本。 (默认值:None)

        transform_sampler_output (callable, optional) – 接受 SamplerOutput 并返回转换后版本的函数/转换。 (默认值:无)

        **kwargs(可选)—— torch.utils.data.DataLoader 的附加参数,例如 batch_size、shuffle、drop_last 或 num_workers。

2 上述参数使用案例:

(1)当 num_neighbors = [-1]时,获取中心节点所有的一阶邻居;

        batch_size=1,表示中心节点只有一个; 

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
import torch
import networkx as nx
import matplotlib.pyplot as plt

data = Planetoid('./dataset', name='Cora')[0]

loader_2 = NeighborLoader(
    data,
    num_neighbors=[-1],
    batch_size=1,
    input_nodes=data.n_id,
)
# 准备边数据
sampled_data_2 = next(iter(loader_2))
# sampled_data_2 输出格式:
# Data(x=[4, 1433], edge_index=[2, 3], y=[4], train_mask=[4], val_mask=[4], test_mask=[4], n_id=[4], batch_size=1)
edge_2 = np.array(sampled_data_2.edge_index).T
edge_2 = edge_2.tolist()
edge_2 = list(tuple(line) for line in edge_2)

# 画图展示
G_2 = nx.Graph()
G_2.add_edges_from(edge_2)
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
option = {'font_family':'serif', 'font_size':'15', 'font_weight':'semibold'}
nx.draw_networkx(G_2, node_size=400, **option)
plt.show()

        画图展示:

pyg的NeighborLoader和LinkNeighborLoader_第1张图片

         代码中的sampled_data_2中的涉及节点的输出:

sampled_data_2.n_id

# tensor([   0,  633, 1862, 2582])
# 前batch_size个节点为中心节点

(2)当 num_neighbors = [2,3]时,获取中心节点所有的一阶邻居(任选取3个节点)以及一阶邻居的邻居(任选取两个节点);

代码展示:

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
import torch
import networkx as nx
import matplotlib.pyplot as plt

data = Planetoid('./dataset', name='Cora')[0]
data.n_id = torch.arange(data.num_nodes)

loader_2 = NeighborLoader(
    data,
    num_neighbors=[2,3],
    batch_size=3,
    input_nodes=data.n_id,
)
# 准备边数据
sampled_data_2 = next(iter(loader_2))
# sampled_data_2 输出格式:
# Data(x=[11, 1433], edge_index=[2, 14], y=[11], train_mask=[11], val_mask=[11], test_mask=[11], n_id=[11], batch_size=3)
edge_2 = np.array(sampled_data_2.edge_index).T
edge_2 = edge_2.tolist()
edge_2 = list(tuple(line) for line in edge_2)

# 画图展示
G_2 = nx.Graph()
G_2.add_edges_from(edge_2)
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
option = {'font_family':'serif', 'font_size':'15', 'font_weight':'semibold'}
nx.draw_networkx(G_2, node_size=400, **option)
plt.show()

pyg的NeighborLoader和LinkNeighborLoader_第2张图片

         代码中的sampled_data_2中的涉及节点的输出:

sampled_data_2.n_id

# tensor([   0,    1,    2,  633, 2582,  654, 1454, 1701, 1866, 1166, 1862])
# 前batch_size个节点为中心节点

3 获得子图的id的映射

        当实际应用中我们要获取训练集和测试集的子图,因此一般输入在NeighborLoader的input_nodes参数的值对应于训练集的id和测试集的id;

        而获得的边对应的id不是实际大图中的节点id,而是后来按照顺序分配的;

例如:

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader,LinkNeighborLoader
import torch
import networkx as nx
import matplotlib.pyplot as plt

data = Planetoid('./dataset', name='Cora')[0]
data.n_id = torch.arange(data.num_nodes)
test_id = torch.tensor([i for i in range(100,120)])

loader_2 = NeighborLoader(
    data,
    num_neighbors=[2,3],
    batch_size=3,
    input_nodes=test_id,
)
# 准备边数据
sampled_data_2 = next(iter(loader_2))
# sampled_data_2 输出格式:
# 
edge_2 = np.array(sampled_data_2.edge_index).T
edge_2 = edge_2.tolist()
edge_2 = list(tuple(line) for line in edge_2)

# 画图展示
G_2 = nx.Graph()
G_2.add_edges_from(edge_2)
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
option = {'font_family':'serif', 'font_size':'15', 'font_weight':'semibold'}
nx.draw_networkx(G_2, node_size=400, **option)
plt.show()

pyg的NeighborLoader和LinkNeighborLoader_第3张图片

 

print(sampled_data_2.edge_index)
print(sampled_data_2.n_id)
print(sampled_data_2.num_nodes)
# 输出
tensor([[ 3,  4,  5,  6,  7,  8,  9, 10, 11, 12,  0, 13, 14,  1, 15,  1, 16, 17,
          2,  8, 18, 19, 20],
        [ 0,  0,  1,  1,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,  5,  6,  6,  6,
          7,  7,  8,  8,  8]])
tensor([ 100,  101,  102, 1602, 2056,  281, 1589, 1561, 1623,   95,  315, 2073,
         734, 1628, 1347, 1382, 1745, 2596, 1769, 1772, 1771])
21

 将图进行可视化时,可以映射回大图中的id

2 LinkNeighborLoader

1 数据格式要求

需要传入加载的属性值:

class LinkNeighborLoader(data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], 
num_neighbors: Union[List[int], Dict[Tuple[str, str, str], List[int]]], 
edge_label_index: Union[Tensor, None, Tuple[str, str, str], Tuple[Tuple[str, str, str], Optional[Tensor]]] = None, 
edge_label: Optional[Tensor] = None, 
edge_label_time: Optional[Tensor] = None, replace: bool = False, 
directed: bool = True, disjoint: bool = False, 
temporal_strategy: str = 'uniform', 
neg_sampling: Optional[NegativeSampling] = None, 
neg_sampling_ratio: Optional[Union[int, float]] = None, 
time_attr: Optional[str] = None, 
transform: Optional[Callable] = None, 
transform_sampler_output: Optional[Callable] = None, 
is_sorted: bool = False, 
filter_per_worker: bool = False, 
neighbor_sampler: Optional[NeighborSampler] = None, **kwargs)

        作为基于节点的 torch_geometric.loader.NeighborLoader 的扩展派生的基于链接的数据加载器。该加载器允许在无法进行整批训练的大规模图上对 GNN 进行小批量训练

        更具体地说,这个加载器首先从输入边 edge_label_index 集合中选择一个边样本(它可能是原始图中的边,也可能不是原始图中的边),然后通过在每次迭代中采样 num_neighbors 个邻居,从这个列表中存在的所有节点构造一个子图.

你可能感兴趣的:(编程,GNN,python,算法,开发语言)