大家好,我是对白。
了解图神经网络的朋友对于深层GNN中的过平滑问题一定不陌生,随着网络层数的增加,模型的效果反而急剧下降,令人心痛。回忆一下,常见的解决过平滑的方案有DropEdge、基于残差的方法还有Normalization等,但效果却不尽人意。
今天在KDD2022看到一篇有趣的文章,它没有止步于GNN中的过平滑问题,而是从另一个新视角去思考深层网络效果骤降的问题——特征维度的过相关问题。所谓特征维度的过相关,顾名思义,指我们所学习到的特征维度之间高度相关,意味着高冗余以及学习到的维度编码的信息较少,从而损害下游任务效果。文中,作者不仅从理论与实践证明了特征维度相关问题的重要性,还提出方案DeCorr,分别基于显示特征维度与最大互信息两种方法实现去相关任务。
从表中可以看出,所提出的方案对于相同的模型层数,在大多数情况下可以实现最佳性能,并显著减缓性能下降。例如,在Cora数据集中,DeCorr分别将15层GCN和30层GCN提高了58.9%和60.3%。
下面就带大家一起领略这个算法的奇妙之处~
论文标题:Feature Overcorrelation in Deep Graph Neural Networks: A New Perspective
论文链接:https://doi.org/10.1145/3534678.3539445
代码链接:https://github.com/ChandlerBang/DeCorr
本文由IBM和密歇根州立大学发表于KDD2022**。**在具体看DeCorr之前,我们有必要先探索一下过平滑与过相关的关系,以便更好地设计模型。
为了评估具体的过相关与过平滑程度,我们采用Corr和SMV作为评估指标,具体公式为:
我们要明确:过相关和过平滑既不相同也不独立。过平滑指的是节点表示之间的相似性,通过节点平滑度来衡量,而过度相关则是通过维度相关来衡量,二者本质不同。Figure 6 中可以看出,在Pubmed和CoauthorCS上,随着Corr值的增加,SMV并未变化较多。
另一方面,它们又高度相关。
过相关和过平滑都使得学习到的表示编码信息量更少,损害下游任务性能;
两种情况都由GNN模型中的多次传播引起,极端情况下的过平滑也会出现过相关的问题。
文章所提出的DeCorr目的在于解决深层GNN中的过相关问题,下面具体来看所提出的方法细节。
方法1:显式特征维数去相关:最小化表示中维数之间相关性。为简单起见,文章采用协方差替代皮尔逊相关系数。给定一组特征维度,我们的目标是最小化损失函数:
最小化第一项会降低不同特征维度之间的协方差,当第一项值为0时,维度之间将不相关。通过最小化第二项,将每个维度的范数(减去平均值后)推到1,然后我们将上式改写为:
此处注意,由于梯度的时间复杂度为,而在真实应用场景中,图的节点数量众多,它是不可扩展的。为此,我们采用蒙特卡洛采样个节点来估计协方差的等概率节点。这样,梯度计算的时间复杂度降为,随图的大小线性增加。
结合去相关损失,最终的损失函数为:
通过最小化损失函数,我们可以明确强制每个层后的表示减少相关性,从而缓解过相关问题。
方法2:互信息最大化:最大化输入和表示之间的互信息,从而使特征更加独立。采用互信息的动机来自ICA,它的原理旨在学习维度相关性较低的表示,同时最大化输入和表示之间的MI。由于深层GNN在表示中的编码信息量更少,MMI可以确保即使模型堆叠了很多层,学习的表示也可以保留来自输入的部分信息。
MI最大化过程公式为:
由于在神经网络的背景下,估计变量与的MI非常困难,这里采用一个很nice的方法——通过样本有效估计高维连续数据互信息(MINE)。具体地方法为,我们通过训练分类器来区分来自联合分布和的样本对来估计互信息的下限。因此,我们的训练目标为:
分类器建模为:
在实际应用中,我们从每个batch中从联合分布采样去估计目标函数的第一项,然后在batch中打乱去生成“负对”去估计第二项。
为减少损耗,仅在每层去应用以加速训练过程:
最终完整的模型损失函数为:
与普通GNN模型相比,所提出模型的额外复杂度可忽略不计;额外时间复杂度为。具体推导过程详见论文。
如图所示,在删除测试集和验证集中的节点特征的情况下(这种情况下一般深层GNN的效果要比浅层好),Table 2列出12种情况,DeCorr在8种情况下实现了最佳性能,显著优于浅层GNN。例如,在Pubmed数据集上,DeCorr在GCN、GAT和ChebyNet上分别实现了36.9%、33.9%和19.5%的改进。
为探索DeCorr实现性能改进的原因,论文绘制了在训练过程中的Corr、SMV以及精确度变化图,证明了深层GNN种过相关问题的重要性。
结合其他深度学习的方法实验结果如下:
论文针对深层GNN网络种效果下降的问题,考虑了除过平滑外的一个新问题——过相关。论文分析了过相关问题的重要性,探索了背后的原因,并设计了一个通用框架DeCorr去改善过相关现象。可以看出,DeCorr在改进深层GNN性能方面达到较好效果,此外,在其他应用场景中也具有潜力。
更多精彩内容请关注 微信公众号 「对白的算法屋」:
清华计算机硕士,现BAT算法工程师,秋招拿过8家大厂算法岗SSP offer;日常分享前沿AI算法、技术干货和职场感悟,帮助你少走弯路进大厂!
关注后回复 【算法】领取AI算法工程师学习路线;
关注后回复 【刷题】领取谷歌师兄的刷题笔记《LeetCode算法题解+代码》;
关注后回复 【书单】领取我精心整理的1000本计算机类的书单,包含编程语言、云计算、大数据、AI、职场晋升和大厂面试题汇总等;
关注后回复 【人工智能】领取我整理的20套机器学习资源,包含ML/DL/CV/NLP等;
关注后回复 【对白笔记】领取我的七本原创算法笔记,它帮助我在60场笔面试中获得100%通过率!
…
更多精彩福利干货,期待您的关注 ~