CS224W图机器学习笔记自用:GNN Augmentation and Training

Recap:

CS224W图机器学习笔记自用:GNN Augmentation and Training_第1张图片
today’s outline:

  • (4)Graph augmentation
  • (5)Learning objective

1. GNN 的图增强(Graph Augmentation for GNNs)

两种图增强的方法:

  • 图特征增强
  • 图结构增强
    CS224W图机器学习笔记自用:GNN Augmentation and Training_第2张图片

1.1 为什么要增强图

需要增强图的原因

  • 特征(Features)
    • 输入图缺乏特征
  • 图结构(Graph structure)
    • 图太稀疏 -> 消息传递效率低下
    • 图太密集 -> 消息传递成本太高
    • 图太 -> 无法将计算图存入GPU中

综上所述,输入图不是嵌入的最佳计算图 。

1.2 图增强的方法

  • 图特征增强
    • 输入图缺乏特征 -> 特征增强
  • 图结构增强
    • 图太稀疏 -> 添加虚拟节点或边
    • 图太密集 -> 消息传递时只采样部分邻居节点进行传递
    • 图太 -> 计算嵌入时对子图进行采样

1.2.1 图特征增强

为什么我们需要特征增强?

  • 第一种情况

    • 输入图没有节点特征,我们只有这个图的邻接矩阵,此时需要进行图特征增强。
  • 解决方法:

    • a)为节点分配常数特征CS224W图机器学习笔记自用:GNN Augmentation and Training_第3张图片

    • b)为节点分配唯一的ID,这些ID值可以被转换为独热向量CS224W图机器学习笔记自用:GNN Augmentation and Training_第4张图片
      两种方法的比较:Constant vs. one-hot
      CS224W图机器学习笔记自用:GNN Augmentation and Training_第5张图片
      第二种方法比第一种方法表达能力更强;第一种方法归纳能力更强,能够很容易地推广到新节点,第二种方法则不行;同时第一种方法的计算开销也更小;第一种方法适用于任意图,同时具有归约能力,能推广到新节点,第二种方法适用于小图,只适用于transductive setting,不适合inductive setting。

  • 第二种情况

    • GNN 很难学习某些特殊结构
    • 例如:环节点数特征(Cycle count feature)
      • GNN无法学习 v 1 v_1 v1所在环的长度,也无法区分 v 1 v_1 v1所在的是哪个图形CS224W图机器学习笔记自用:GNN Augmentation and Training_第6张图片
      • 因为这两张图中的所有节点度数都为2
      • 计算图也是完全相同的二叉树CS224W图机器学习笔记自用:GNN Augmentation and Training_第7张图片
  • 解决方法:我们可以使用循环计数作为增强的节点特征
    CS224W图机器学习笔记自用:GNN Augmentation and Training_第8张图片

其他常用的增强特征有:

  1. 节点度数
  2. 聚类系数(Clustering coefficient)
  3. PageRank
  4. Centrality
    我们在第二节所提到的节点特征都可以使用。

1.2.2 图结构增强

  1. 针对图稀疏添加虚拟节点或边

    1. 添加虚拟边
      常用方法:通过虚拟边连接 2 跳邻居
      想法:用 A + A 2 A + A^2 A+A2代替邻接矩阵 A A A进行GNN的计算
      例子:二部图
      • Author-to-papers (他们撰写的)
      • 2 跳虚拟边构成作者-作者协作图CS224W图机器学习笔记自用:GNN Augmentation and Training_第9张图片
    2. 添加虚拟节点虚拟节点将连接到图中的所有节点
      • 假设在一个稀疏图中,两个节点的最短路径距离为 10
      • 添加虚拟节点后,所有节点的距离为 2
        • Node A - Virtual node - Node B
      • 好处:大大提高了稀疏图中的消息传递CS224W图机器学习笔记自用:GNN Augmentation and Training_第10张图片
  2. 针对图密集的问题:节点邻域采样

    思想:在之前的设计中,所有节点参与消息传递,现在,我们**(随机)对节点的邻域进行采样以进行消息传递**,以解决图密集的问题。
    例子:例如,我们可以随机选择 2 个邻居在给定层中传递消息CS224W图机器学习笔记自用:GNN Augmentation and Training_第11张图片
    下一层,当我们计算嵌入时,我们可以采样不同的邻居(对于类似于社交网络的图,也可以仅采样一些重要的节点,不必采样那些不重要的节点)
    CS224W图机器学习笔记自用:GNN Augmentation and Training_第12张图片
    CS224W图机器学习笔记自用:GNN Augmentation and Training_第13张图片

    最后在预期中,我们得到类似于使用所有邻居的情况的嵌入。
    这种方法的好处:可以大大降低计算成本,并且允许scaling to 大图,在实践中的效果也很好。

2. Training with GNNs

CS224W图机器学习笔记自用:GNN Augmentation and Training_第14张图片
Learning so far:
CS224W图机器学习笔记自用:GNN Augmentation and Training_第15张图片

2.1 Prediction head:如何从节点嵌入到实际预测

CS224W图机器学习笔记自用:GNN Augmentation and Training_第16张图片
**预测头(prediction head)**有以下几种类型:

  • 节点级任务
  • 边级别任务
  • 图级别任务

不同的任务级别需要不同的预测头CS224W图机器学习笔记自用:GNN Augmentation and Training_第17张图片

2.1.1 节点级预测头

1. 节点级预测:我们可以直接使用节点嵌入进行预测

  • 在 GNN 计算之后,我们有d维的节点嵌入 { h v ( L ) ∈ R d , ∀ v ∈ G } \{ h_v^{(L)} \in R^d,\forall v \in G \} {hv(L)Rd,vG}
  • 假设我们要进行一个k类别的预测
    • 分类问题:在k个类别中分类
    • 回归问题:回归k个目标
  • y ^ v = H e a d n o d e ( h v ( L ) ) = W ( H ) h v ( L ) \hat{y}_v = Head_{node}(h_v^{(L)})=W^{(H)}h_v^{(L)} y^v=Headnode(hv(L))=W(H)hv(L)
    • W ( H ) ∈ R k × d W^{(H)} \in R^{k \times d} W(H)Rk×d: 我们映射节点嵌入从 h v ( L ) ∈ R d h_v^{(L)} \in ℝ^d hv(L)Rd y ^ v ∈ R k \hat{y}_v \in ℝ^k y^vRk,这样我们就可以计算损失

2.1.2 边级别预测头

2. 边级别预测:使用节点嵌入对进行预测

  • 假设我们要进行一个k类别的预测
  • y ^ u v = H e a d e d g e ( h u ( L ) , h v ( L ) ) \hat{y}_{uv} = Head_{edge}(h_u^{(L)},h_v^{(L)}) y^uv=Headedge(hu(L),hv(L))
    CS224W图机器学习笔记自用:GNN Augmentation and Training_第18张图片
  • H e a d n o d e ( h v ( L ) ) = W ( H ) h v ( L ) Head_{node}(h_v^{(L)})=W^{(H)}h_v^{(L)} Headnode(hv(L))=W(H)hv(L)有多种选择
    • (1) 串联 + 线性
      • 在图注意力网络也有类似的架构CS224W图机器学习笔记自用:GNN Augmentation and Training_第19张图片
      • y ^ u v = L i n e a r ( C o n c a t ( h u ( L ) , h v ( L ) ) ) \hat{y}_{uv} = Linear(Concat(h_u^{(L)},h_v^{(L)})) y^uv=Linear(Concat(hu(L),hv(L)))
      • 这里线性映射函数Linear(.)会把2d维的嵌入向量映射到k维(k个类别)的嵌入中
    • (2)点积
      • y ^ u v = ( h u ( L ) ) T h v ( L ) \hat{y}_{uv} = (h_u^{(L)})^T h_v^{(L)} y^uv=(hu(L))Thv(L)
      • 这种方法仅适用于 1-way 预测(例如,链接预测:预测边缘的存在
      • 应用到 k-way 预测上,类似于多头注意力机制 W ( 1 ) , . . . , W ( k ) W^{(1)},... ,W^{(k)} W(1)...,W(k)是可训练的参数CS224W图机器学习笔记自用:GNN Augmentation and Training_第20张图片

2.1.3 图级别预测头

3. 图级别预测使用图中的所有节点嵌入进行预测

  • 假设我们要进行一个k类别的预测
  • y ^ G = H e a d g r a p h ( { h v ( L ) ∈ R d , ∀ v ∈ G } \hat{y}_G = Head_{graph}(\{h_v^{(L)} \in R^d, \forall v \in G\} y^G=Headgraph({hv(L)Rd,vG}
  • H e a d g r a p h ( ⋅ ) Head_{graph}(\cdot) Headgraph()类似于 GNN 层中的聚合函数 A G G ( ⋅ ) AGG(\cdot) AGG()CS224W图机器学习笔记自用:GNN Augmentation and Training_第21张图片
  • H e a d g r a p h ( { h v ( L ) ∈ R d , ∀ v ∈ G } Head_{graph}(\{h_v^{(L)} \in R^d, \forall v \in G\} Headgraph({hv(L)Rd,vG}有多种选择:
    • 全局平均池化层:与节点数无关,mean pooling可用于比较大小相差很大的图形
      • y ^ G = M e a n ( { h v ( L ) ∈ R d , ∀ v ∈ G } ) \hat{y}_G = Mean(\{h_v^{(L)} \in R^d, \forall v \in G\}) y^G=Mean({hv(L)Rd,vG})
    • 全局最大池化层
      • y ^ G = M a x ( { h v ( L ) ∈ R d , ∀ v ∈ G } ) \hat{y}_G = Max(\{h_v^{(L)} \in R^d, \forall v \in G\}) y^G=Max({hv(L)Rd,vG})
    • 全局求和池化层:max pooling可以发现图中的节点数和图的结构
      • y ^ G = S u m ( { h v ( L ) ∈ R d , ∀ v ∈ G } ) \hat{y}_G = Sum(\{h_v^{(L)} \in R^d, \forall v \in G\}) y^G=Sum({hv(L)Rd,vG})
  • 全局池化层的问题:以上选项的全局池化层都只适用于小规模的图形,在大图上应用全局池化层会有信息丢失的问题,例如CS224W图机器学习笔记自用:GNN Augmentation and Training_第22张图片
  • 解决方法分层全局池化(分层聚合所有节点嵌入)
    • example:先聚合前两个节点,在聚合后两个节点CS224W图机器学习笔记自用:GNN Augmentation and Training_第23张图片
      现在,我们就能区分图1和图2。

那我们如何分层呢?

  • DiffPool:
    • 分层池化节点嵌入:利用图的社区结构,如果我们可以提前发现这些社区,那么我们就可以把每个社区当作一层聚合社区内的节点信息,接着我们可以进一步将社区嵌入汇总到超级社区嵌入。如下图,输入图用社区检测图分区算法分成了5个簇,这里用不同颜色表示,接着我们再汇总社区内的信息为每个社区生成一个超级节点,之后我们根据社区之间的联系再进行分簇,聚合,得到另一个超节点并不断聚合直到得到一个超级节点为止,然后就可以将其输入到预测头中:CS224W图机器学习笔记自用:GNN Augmentation and Training_第24张图片
    • Ying 等人(2018)提出的DiffPool在每个级别利用 2 个独立的 GNN
      • GNN A:计算节点嵌入
      • GNN B:进行图分区,判断节点所属的集群
    • 每个级别的 GNN A 和 B 可以并行执行
    • 对于每个池化层
      • 根据 GNN B 的聚类社区分配结果来聚合由 GNN A 生成的节点嵌入
      • 为每个集群创建一个新节点,维护集群之间的边以生成新的池化网络
    • 联合训练 GNN A 和 GNN B

2.2 Predictions and Labels

CS224W图机器学习笔记自用:GNN Augmentation and Training_第25张图片

2.2.1 监督学习 VS. 无监督学习

  1. 监督学习 VS. 无监督学习
    • 图上的监督学习标签来自外部来源,例如,预测分子图的药物相似性
    • 图上的无监督学习信号来自图本身,例如,链接预测:预测两个节点是否连接
    • 有时这些差异的界限是模糊的:我们在无监督学习中仍然有“监督”,例如,训练一个 GNN 来预测节点聚类系数,“无监督”也被称为“自我监督”
  2. 图上的监督标签
    • 监督标签来自特定的用例,例如:
      • 节点标签 y v y_v yv:在引文网络中,节点标签是节点属于哪个学科领域
      • 边的标签 y u v y_{uv} yuv:在交易网络中,边的标签是边是否具有欺诈性
      • 图的标签 y G y_G yG:分子图中,图标签是图的药物相似度
        -Advice:将您的任务减少到节点/边/图形标签,这样我们就能使用现有的框架。
  3. 图上的非监督标签
    • Problem:有时我们只有一个图,没有任何外部标签
    • 解决方法:自我监督学习,我们可以在图中找到监督信号
    • 以下任务不需要任何外部标签
      • 节点级别 y v y_v yv节点统计(如聚类系数、PageRank、…)或预测节点的属性
      • 边级别 y u v y_{uv} yuv:**链接预测(**隐藏两个节点之间的边,预测是否应该有链接)
      • 图级别 y G y_{G} yG图统计(例如,预测两个图是否同构)

2.3 Loss Function

CS224W图机器学习笔记自用:GNN Augmentation and Training_第26张图片
CS224W图机器学习笔记自用:GNN Augmentation and Training_第27张图片

2.3.1 分类 VS. 回归

  • 分类(Classification):节点的标签 y ( i ) y^{(i)} y(i)具有离散值
    • 例如,节点分类:节点属于哪个类别
  • 回归(Regression):节点的标签 y ( i ) y^{(i)} y(i)具有连续值
    • 例如,预测分子图的药物相似性或毒性水平
  • GNNs可以应用于这两类问题, 不同的在于损失函数评估指标

2.3.2 分类问题损失函数

  • 交叉熵 (cross entropy CE) 是分类中非常常见的损失函数
  • 我们要预测第i个数据点的类别(一共有K类)CS224W图机器学习笔记自用:GNN Augmentation and Training_第28张图片
  • 其它类型损失函数: H i n g e L o s s (铰链损失) Hinge Loss(铰链损失) HingeLoss(铰链损失)
    • 在"maximum-margin"的分类任务中,如支持向量机,表示预测输出,通常都是软结果(输出不是0,1这种,可能是0.87), 表示正确的类别,我们用下式作为分类函数:
      H i n g e L o s s = m a x ( 0 , m − y ^ y ) Hinge Loss = max(0, m - \hat{y}y) HingeLoss=max(0,my^y)
    • 很多时候我们希望训练的是两个样本之间的相似关系,而非样本的整体分类,所以很多时候我们会用下面的公式:
      H i n g e L o s s = m a x ( 0 , m − y + y ^ ) Hinge Loss = max(0, m - y + \hat{y}) HingeLoss=max(0,my+y^)
    • 其中,是y正样本的得分,是 y ^ \hat{y} y^负样本的得分,m是margin,即我们希望正样本分数越高越好,负样本分数越低越好,但二者得分之差最多到m就足够了,差距增大并不会有任何奖励。

2.3.3 回归问题损失函数

  • 对于回归任务,我们经常使用均方误差 (MSE) 也就是 L2 损失
  • 数据点i的k-way回归CS224W图机器学习笔记自用:GNN Augmentation and Training_第29张图片

2.4 Evaluation metrics

CS224W图机器学习笔记自用:GNN Augmentation and Training_第30张图片
在回归问题上,我们使用 GNN 的标准评估指标,在实践中我们通常使用sklearn程序包来实现,假设我们对 N 个数据点进行预测

2.4.1 回归问题分类指标

在图上评估回归任务,我们可以使用根均方差(RMSE)平均绝对误差(MAE) 这两个指标来评价:CS224W图机器学习笔记自用:GNN Augmentation and Training_第31张图片

2.4.2 分类问题分类指标

在图上评估分类任务:

  • (1) 多类分类
    • 只报告准确性CS224W图机器学习笔记自用:GNN Augmentation and Training_第32张图片
  • (2)二类分类
    • 对分类阈值敏感的指标

      • Accuracy (准确率)
      • Precision(精确率) / Recall (召回率)
      • 如果预测的范围是 [0,1],我们将使用 0.5 作为阈值 CS224W图机器学习笔记自用:GNN Augmentation and Training_第33张图片
    • 与分类阈值无关的指标

      • ROC Curve:捕获 TPR 和 FPR 的权衡,因为二元分类器的分类阈值是变化的(虚线表示随机分类器的性能)CS224W图机器学习笔记自用:GNN Augmentation and Training_第34张图片
      • ROC AUC:RUC曲线下的面积(Area under the ROC Curve),是分类器将随机选择的正实例得分高于随机选择的负实例的概率

3. 数据集拆分(训练/验证/测试集)

CS224W图机器学习笔记自用:GNN Augmentation and Training_第35张图片

3.1 常规拆分方案

  1. 固定拆分:我们将一次性分割我们的数据集
    • 训练集:用于优化 GNN 参数
    • 验证集:用于调整超参数和各种常数及决策选择
    • 测试集:只用于评估模型的最终性能
      我们用训练集和验证集确定最终模型,然后将模型应用到测试集
  2. 随机拆分:我们将数据集随机拆分为训练/验证/测试
    • 我们报告了不同随机种子的拆分方案平均性能

3.2 图的数据集拆分方案

  • 拆分图和拆分一般的数据集不一样,会造成数据泄露的问题
    • 在文档数据集或图像数据集中,我们拆分数据集时,假设数据点之间相互独立,这样很容易将其拆分成三个数据集,并且没有数据泄漏CS224W图机器学习笔记自用:GNN Augmentation and Training_第36张图片
    • 然而拆分图数据集是不一样的,图的问题在于节点之间相互连接,不是相互独立的,节点会从其他节点收集信息,这样会造成信息泄露的问题。CS224W图机器学习笔记自用:GNN Augmentation and Training_第37张图片
  • 解决方案 1(Transductive setting)只拆分节点标签,保持图的结构不变,整个输入图在所有数据集中都是可见的(即使用整个图计算嵌入)
    • 只拆分(节点)标签
      • 在训练时,我们使用整个图计算嵌入,并使用节点 1 和 2 的标签进行训练
      • 在验证时,我们使用整个图计算嵌入,并评估节点 3 和 4 的标签CS224W图机器学习笔记自用:GNN Augmentation and Training_第38张图片
  • 解决方案 2(Inductive setting)删除拆分出的数据集之间连接的边
    • 现在我们有 3 个独立的图。节点 5 将不再影响我们对节点 1 的预测
    • 在训练时,我们使用节点 1&2 上的图计算嵌入,并使用节点 1&2 的标签进行训练
    • 在验证时,我们使用节点 3&4 上的图计算嵌入,并评估节点 3&4 的标签
  • 两种方案的比较:Transductive / Inductive Settings
    • Transductive Settings: 训练/验证/测试集在同一张图上
      • 数据集由单个图组成
      • 可以在所有数据集拆分中观察到整个图,只拆分标签
      • 仅适用于节点/边预测任务
  • Inductive Settings:训练/验证/测试集在不同的图表上
    • 数据集由多个图组成
    • 每个拆分只能观察拆分内的图,这使我们能够真正测试如何将其推广到看不见的图形,一个成功的模型应该泛化到看不见的图
    • 适用于节点/边/图任务

3.3 图的拆分示例

  1. 节点分类
    CS224W图机器学习笔记自用:GNN Augmentation and Training_第39张图片
  2. 图分类
    在图分类问题中,由于我们分类独立的图,因此归纳设置不需要删除边就能应用,我们可以方便地将其分为训练、验证和测试集。
    CS224W图机器学习笔记自用:GNN Augmentation and Training_第40张图片
  3. 链接预测
    • 链接预测设置是图机器学习中最棘手的任务:它是一项无监督/自我监督的任务,我们需要自己创建标签和数据集拆分
    • 具体来说,我们需要对 GNN 隐藏一些边,并让 GNN 预测这些边是否存在CS224W图机器学习笔记自用:GNN Augmentation and Training_第41张图片
    • 对于链接预测,我们将两次分割边
    • 第 1 步在原始图中分配 2 种类型的边
      • 消息边:用于 GNN中的 消息传递
      • 监督边:用于计算目标
      • 第一步之后
        • 图中仅保留消息边,移除监督边
        • 监督边用作模型对边预测的监督,不会被输入 GNNCS224W图机器学习笔记自用:GNN Augmentation and Training_第42张图片
      • 第 2 步将边拆分为训练/验证/测试
      • 选项 1:归纳链接预测拆分
        • 假设我们有一个包含 3 个图的数据集。每个归纳拆分将包含一个独立的图CS224W图机器学习笔记自用:GNN Augmentation and Training_第43张图片
        • 假设我们有一个包含 3 个图的数据集。每个归纳分裂将包含一个独立的图
        • 在训练或验证或测试集中,每个图将有 2 种类型的边:消息边 + 监督边(监督边不是 GNN 的输入)CS224W图机器学习笔记自用:GNN Augmentation and Training_第44张图片
      • 选项2:Transductive链路预测分割(默认选项)
        • 根据“转导”的定义,可以在所有数据集拆分中观察到整个图
        • 训练时:使用训练消息边预测训练监督边
        • 验证时:使用训练消息边和训练监督边预测验证边
        • 测试时:使用训练消息边、训练监督边和验证边 预测 测试边CS224W图机器学习笔记自用:GNN Augmentation and Training_第45张图片 - Transductive链路预测分割将图的边分为四类:训练消息边、训练监督边、验证边、测试边,链接预测设置既棘手又复杂,您可能会发现论文以不同的方式进行链接预测。幸运的是,我们完全支持 PyG 和 GraphGym来帮助我们进行链接预测。

4. Summary: GNN Training Pipeline

CS224W图机器学习笔记自用:GNN Augmentation and Training_第46张图片
实现资源:

  • DeepSNAP 为该管道提供核心模块
  • GraphGym 进一步实现全流水线以方便 GNN 设计
    CS224W图机器学习笔记自用:GNN Augmentation and Training_第47张图片

5. Tips:When Things Don’t Go As Planned

5.1 通用提示

  • 数据预处理很重要
    • 节点属性的变化范围很大,从(0,1)到(-1000,1000)都有可能
    • 因此需要进行标准化
  • 优化器的选择
    • ADAM 对学习率相对稳健
  • 激活函数
    • ReLU 激活函数通常效果很好
    • 其他替代方案:LeakyReLU、SWISH、rational activation
    • 输出层没有激活函数
  • 在每一层中包含偏置项
  • 嵌入维度:32、64 和 128 通常是很好的起点

5.2 调试深度网络

调试问题损失/准确性在训练期间未收敛

  • 检查管道(例如在 PyTorch 中我们需要 zero_grad)
  • 调整学习率等超参数
  • 注意权重参数初始化

对模型开发很重要的问题

  • 在(部分)训练数据上过拟合
    • 对于一个小的训练数据集,损失应该基本上接近于 0,对于一个表达神经网络
    • 如果神经网络不能过拟合单个数据点,那是错误的
  • 仔细检查损失函数
  • 仔细检查可视化

5.3 图神经网络资源

CS224W图机器学习笔记自用:GNN Augmentation and Training_第48张图片
论文阅读:
CS224W图机器学习笔记自用:GNN Augmentation and Training_第49张图片

你可能感兴趣的:(CS224W图神经网络,机器学习,人工智能,算法)