【图神经网络】学习聚合函数 GraphSAGE

本文为图神经网络学习笔记,讲解学习聚合函数 GraphSAGE。欢迎在评论区与我交流

前言

本教程在 PPI(蛋白质网络)数据集上用 Tensorflow 搭建 GraphSAGE 框架中的 MaxPooling 聚合模型实现有监督下的图节点标签预测任务。

GraphSAGE 简介

GraphSAGE 是一种在超大规模图上,利用节点的属性信息高效产生未知节点特征表示归纳式学习框架。GraphSAGE 可以被用来生成节点的低维向量表示,尤其对于具有丰富节点属性的 Graph 效果显著。

目前大多数的框架都是直推式学习模型,即只能够在一张固定的 Graph 上进行表示学习,这样既不能够对那些在训练中未见的节点进行有效的向量表示,也不能够跨图进行节点表示学习。GraphSAGE 作为一种归纳式的表示学习框架,能够利用节点丰富的属性信息有效地生成未知节点的特征表示。

【图神经网络】学习聚合函数 GraphSAGE_第1张图片

GraphSAGE的核心思想是通过学习一个对邻居节点进行聚合表示的函数,来产生中心节点的特征表示,而不是学习节点本身的 embedding。它既可以进行监督学习也可以进行无监督学习,GraphSAGE 中的聚合函数有以下几种:

  • Mean Aggregator

    Mean 聚合近似等价 GCN 中的卷积传播操作。具体来说就是对中心节点的邻居节点的特征向量求均值,然后和中心节点特征向量拼接,中间有两次非线性变换。

  • GCN Aggregator

    GCN的归纳式学习版本

  • Pooling Aggregator

    【图神经网络】学习聚合函数 GraphSAGE_第2张图片

    先对中心节点的邻居节点表示向量进行一次非线性变换,然后对变换后的邻居表示向量进行池化操作(mean pooling 或者 max pooling),最后将 pooling 所得结果与中心节点的特征表示分别进行非线性变换,并将所得结果进行拼接或者相加从而得到中心节点在该层的向量表示。

  • LSTM Aggregator

    将中心节点的邻居节点随机打乱作为输入序列,将所得向量表示与中心节点的向量表示分别经过非线性变换后拼接,得到中心节点在该层的向量表示。LSTM 本身用于序列数据,而邻居节点没有明显的序列关系,因此输入到 LSTM 中的邻居节点需要随机打乱顺序

以 MaxPooling 聚合方法为例构建 GraphSAGE 模型进行有监督学习下的分类任务。

PPI 数据集

PPI(Protein-protein interaction networks)数据集由 24 个对应人体不同组织的图组成。其中 20 个图用于训练,2 个图用于验证,2 个图用于测试。平均每张图有 2372 个节点,每个节点有 50 个特征。测试集中的图与训练集中的图没有交叉,即在训练阶段测试集中的图是不可见的。每个节点拥有多种标签,标签的种类总共有 121 种。

构建模型

我们使用的核心库是 tf_geometric,借助这个 GNN 库可以方便地导入数据集,预处理图数据以及搭建图神经网络。另外我们还引用了 tf.keras.layers 中的 Dropout 缓解过拟合,以及 sklearn 中的 micro f1_score 函数作为评价指标。

导入库函数:

# coding=utf-8
import os
import tensorflow as tf
from tensorflow import keras
import numpy as np
from tf_geometric.layers.conv.graph_sage import  MaxPoolingGraphSage
from tf_geometric.datasets.ppi import PPIDataset
from sklearn.metrics import f1_score
from tqdm import tqdm
from tf_geometric.utils.graph_utils import RandomNeighborSampler 

加载数据集,使用 tf_geometric自带的PPI数据集。 tf_geometric 提供了简单的图数据构建接口,只需要传入简单的 Python 数组或 Numpy 数组作为节点特征和邻居表就可以构建自己的数据集,如 GIN。

# 使用 tf_geometric 自带的 PPI 数据集,返回划分好的训练集(20),验证集(2),测试集(2)。
train_graphs, valid_graphs, test_graphs = PPIDataset().load_data()

由于每个节点的邻居节点的数目不一,出于计算效率的考虑,我们对每个节点采样一定数量的邻居节点作为之后聚合领域信息时的邻居节点。设定采样数量为 num_sample,如果邻居节点的数量大于 num_sample,采用无放回采样。如果邻居节点的数量小于 num_sample,采用有放回采样,直到所采样的邻居节点数量达到 num_sample

# traverse all graphs
for graph in train_graphs + valid_graphs + test_graphs:
  	# andomNeighborSampler 提前对每张图进行预处理,将相关的图信息与各自的图绑定
    neighbor_sampler = RandomNeighborSampler(graph.edge_index)
    # 模型可能会同时作用在多个图上,要保证每张图的邻居节点在抽样结束后不混淆
    # 将抽样结果与每个 Graph 对象绑定。即将抽样信息保存在“cache"缓存字典中
    graph.cache["sampler"] = neighbor_sampler

采用两层 MaxPooling 聚合函数来聚合 Graph 中邻居节点蕴含的信息:

# 邻居节点采样数目分别为 25 和 10
graph_sages = [
    MaxPoolingGraphSage(units=128, activation=tf.nn.relu),
    MaxPoolingGraphSage(units=128, activation=tf.nn.relu)
]

# 用 Sequential 快速创建神经网络
fc = tf.keras.Sequential([
    keras.layers.Dropout(0.3), # 使用 dropout
    tf.keras.layers.Dense(num_classes)
])

num_sampled_neighbors_list = [25, 10]

def forward(graph, training=False): # 前向传播
    neighbor_sampler = graph.cache["sampler"] # 从图 sampler 缓存中取出邻居样本
    h = graph.x
    for i, (graph_sage, num_sampled_neighbors) in enumerate(zip(graph_sages, num_sampled_neighbors_list)):
      	# 之前已经通过 `RandomNeighborSampler` 为每张图处理好相关的图结构信息
      	# sampled_edge_index 边
        # sampled_edge_weights 边权
        sampled_edge_index, sampled_edge_weight = neighbor_sampler.sample(k=num_sampled_neighbors) # 采样 num_sampled_neighbors 个样本
        # 将采样得到的边、边权、节点的特征向量输入到 GraphSAGE 模型
        # Dropout 层在训练和预测阶段的状态不同,通过参数 training 来控制
        h = graph_sage([h, sampled_edge_index, sampled_edge_weight], training=training)

    h = fc(h, training=training)

    return h

max_pooling_graph_sage 的具体实现

MaxPooling 聚合函数是一个带有 max-pooling 操作的单层神经网络。首先传递每个中心节点的邻居节点向量到一个非线性层中。由于 tf_geometric 基于边表结构进行相关 Graph 操作,所以先用 gather 转换得到所有节点的邻居节点的特征向量组成的特征矩阵:

row, col = edge_index
# x 是 Graph 中的节点特征矩阵
repeated_x = tf.gather(x, row) # row 是 Graph 中的源节点序列
neighbor_x = tf.gather(x, col) # col 是 Graph 中的目标节点序列

tf.gather 根据节点序列从节点特征矩阵中选取对应的节点特征,堆叠形成所有邻居节点组成的特征矩阵。tf.gather 具体操作如下:

得到加权后邻居节点特征向量:

neighbor_x = gcn_mapper(repeated_x, neighbor_x, edge_weight=edge_weight)

在进行 max-pooling 前将所有邻居节点的特征向量输入全连接网络,计算邻居节点的特征表示(将 MLP 看做是一组函数):

neighbor_x = dropout(neighbor_x) # 经过 dropout,计算特征表示
h = neighbor_x @ mlp_kernel
if mlp_bias is not None:
    h += mlp_bias

if activation is not None:
    h = activation(h)

对邻居节点特征向量进行 max-pooling ,然后将所得向量与经过变换的中心节点特征向量拼接输出。 理想的聚合方法应该是简单、可学习且对称的。换句话说,理想的 aggregator 应该学会如何聚合邻居节点的特征表示,并对邻居节点的顺序不敏感,同时不会造成巨大的训练开销:

reduced_h = max_reducer(h, row, num_nodes=len(x))
reduced_h = dropout(reduced_h) # 进行 dropout
x = dropout(x)

from_neighs = reduced_h @ neighs_kernel
from_x = x @ self_kernel
output = tf.concat([from_neighs, from_x], axis=1)
if bias is not None:
    output += bias

if activation is not None:
    output = activation(output)

if normalize:
    output = tf.nn.l2_normalize(output, axis=-1) # 进行归一化

return output

GraphSAGE 训练

模型的训练与其他基于 Tensorflow 框架的模型训练基本一致,主要步骤有定义优化器,计算误差与梯度,反向传播等。正向传播算法的伪代码:

【图神经网络】学习聚合函数 GraphSAGE_第3张图片

GraphSAGE 论文用模型在第 10 轮训练后的表现来评估模型,因此这里我们将 epoches 设置为 10:

# 定义优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2)

for epoch in tqdm(range(10)): # epoches = 10
    for graph in train_graphs:
        with tf.GradientTape() as tape:
            logits = forward(graph, training=True) # 前向传播,执行 Dropout
            # 预测阶段输入为 valid_graphs 或 test_graphs 时,training=False 不 Dropout
            loss = compute_loss(logits, tape.watched_variables()) # 计算误差

        vars = tape.watched_variables()
        grads = tape.gradient(loss, vars) # 计算梯度
        optimizer.apply_gradients(zip(grads, vars)) # 优化器进行优化

    if epoch % 1 == 0:
        valid_f1_mic = evaluate(valid_graphs)
        test_f1_mic = evaluate(test_graphs) # 测试
        print("epoch = {}\tloss = {}\tvalid_f1_micro = {}".format(epoch, loss, valid_f1_mic))
        print("epoch = {}\ttest_f1_micro = {}".format(epoch, test_f1_mic))

计算模型损失。由于 PPI 数据集中每个节点有多个标签,属于多标签、多分类任务,因此选用 sigmoid 交叉熵函数:

def compute_loss(logits, vars):
    losses = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=logits, # logits 是模型对节点标签的预测结果
        labels=tf.convert_to_tensor(graph.y, dtype=tf.float32) # graph.y 是图节点的真实标签
    )
		
    kernel_vals = [var for var in vars if "kernel" in var.name]
    # 防止过拟合,对模型的参数使用 L2 正则化。
    l2_losses = [tf.nn.l2_loss(kernel_var) for kernel_var in kernel_vals]

    return tf.reduce_mean(losses) + tf.add_n(l2_losses) * 1e-5

GrapSAGE 评估

使用 F1 Score 评估 MaxPoolingGraphSAGE 聚合邻居节点信息进行分类任务的性能:

def evaluate(graphs):
    y_preds = []
    y_true = []

    for graph in graphs:
        y_true.append(graph.y)
        # 将测试集中的图(训练阶段不可见)输入到经训练的 MaxPoolingGraphSAGE,得到预测结果
        logits = forward(graph)
        y_preds.append(logits.numpy())

    # 预测结果与其对应的 labels 转换为一维数组
    y_pred = np.concatenate(y_preds, axis=0)
    y = np.concatenate(y_true, axis=0)
	 	
    # 输入到 sklearn 中的 f1_score 方法,得到 F1_Score
    mic = calc_f1(y, y_pred)

    return mic

运行结果

epoch = 1	loss = 0.5231980085372925	valid_f1_micro = 0.45228990047917433
epoch = 1	test_f1_micro = 0.45506719065662915
 27%|██▋       | 3/11 [01:11<03:12, 24.11s/it]epoch = 2	loss = 0.5082718729972839	valid_f1_micro = 0.4825462475136504
epoch = 2	test_f1_micro = 0.4882603340749235
epoch = 3	loss = 0.49998781085014343	valid_f1_micro = 0.4906942451215627
epoch = 3	test_f1_micro = 0.502555249743498
 45%|████▌     | 5/11 [01:55<02:16, 22.79s/it]epoch = 4	loss = 0.4901132583618164	valid_f1_micro = 0.5383310665693446
epoch = 4	test_f1_micro = 0.5478608072643453
epoch = 5	loss = 0.484283983707428	valid_f1_micro = 0.5455753374297568
epoch = 5	test_f1_micro = 0.5516753473281046
 64%|██████▎   | 7/11 [02:41<01:31, 22.95s/it]epoch = 6	loss = 0.4761819541454315	valid_f1_micro = 0.5417373280572828
epoch = 6	test_f1_micro = 0.5504290907273931
 73%|███████▎  | 8/11 [03:03<01:08, 22.71s/it]epoch = 7	loss = 0.46836230158805847	valid_f1_micro = 0.5720065995217665
epoch = 7	test_f1_micro = 0.5843164717276317
 82%|████████▏ | 9/11 [03:24<00:44, 22.34s/it]epoch = 8	loss = 0.4760943651199341	valid_f1_micro = 0.5752257074185534
epoch = 8	test_f1_micro = 0.5855495700393325
 91%|█████████ | 10/11 [03:47<00:22, 22.34s/it]epoch = 9	loss = 0.461212694644928	valid_f1_micro = 0.5812645586399496
epoch = 9	test_f1_micro = 0.5930584548044271
100%|██████████| 11/11 [04:08<00:00, 22.61s/it]
epoch = 10	loss = 0.4568028450012207	valid_f1_micro = 0.5833869662874881
epoch = 10	test_f1_micro = 0.5964539684054789

有帮助的话点个赞加关注吧

完整代码见【demo_graph_sage.py】,完整教程见【Tensorflow-GraphSAGE-Tutorial】,论文见【GraphSAGE】。

你可能感兴趣的:(#,深度学习,python,tensorflow,机器学习,人工智能,深度学习)