基于GNN,强于GNN:胶囊图神经网络的PyTorch实现 | ICLR 2019

点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达4491ca30466754eed5ce4b40776441eb.jpeg

基于GNN,强于GNN:胶囊图神经网络的PyTorch实现 | ICLR 2019_第1张图片

【导读】胶囊图神经网络(CapsGNN)是在GNN启发下诞生了基于图片分类的新框架。CapsGNN在10个数据集中的6个的表现排名位居前两名。与所有其他端到端架构相比,CapsGNN在所有社交数据集中均名列首位。

本日Reddit上热议的一个话题是名为“胶囊图神经网络”(CapsGNN)的新框架。从名字不难看出,它是受图神经网络(GNN)的启发,在其基础上改进而来的成果。

CapsGNN框架的作者为新加坡南洋理工大学电气与电子工程学院的Zhang Xinyi和Lihui Chen,该研究的论文将在ICLR 2019上发表。

基于GNN,强于GNN:胶囊图神经网络的PyTorch实现 | ICLR 2019_第2张图片

目前,从图神经网络(GNN)中学到的高质量节点嵌入已经应用于各种基于节点的应用程序中,其中一些程序已经实现了最先进的性能。不过,当应用程序用GNN学习的节点嵌入来生成图形嵌入时,标量节点表示可能不足以有效地保留节点或图形的完整属性,从而导致图形嵌入的性能达不到最优。

胶囊图神经网络(CapsGNN)受到了胶囊神经网络的启发,利用胶囊的概念来解决现有基于GNN的图嵌入算法的缺点。CapsGNN以胶囊形式对节点特征进行提取,利用路由机制来捕获图形级别的重要信息。因此,模型会为每个图生成多个嵌入,从多个不同方面捕获图的属性。

CapsGNN中包含的注意力模块可用于处理各种尺寸的图,让模型能够专注处理图的关键部分。通过对10个图结构数据集的广泛评估表明,CapsGNN具有强大的机制,可通过数据驱动捕获整个图的宏观属性。在几个图分类任务上的性能优于其他SOTA技术。

胶囊图神经网络基本架构

基于GNN,强于GNN:胶囊图神经网络的PyTorch实现 | ICLR 2019_第3张图片

上图所示为CapsGNN的简化版本。它由三个关键模块组成:1)基本节点胶囊提取模块:GNN用于提取具有不同感受野的局部顶点特征,然后在该模块中构建主节点胶囊。2)高级图胶囊提取模块:融合了注意力模块和动态路由,以生成多个图胶囊。3)图分类模块:再次利用动态路由,生成用于图分类的类胶囊。

注意力模块

在CapsGNN中,基于每个节点提取主胶囊,即主胶囊的数量取决于输入图的大小。在这种情况下,如果直接应用路由机制,则生成的高级别的胶囊的值将高度依赖于主胶囊的数量(图大小),这种情况并不理想。因此,实验引入一个注意力模块来解决这个问题。

基于GNN,强于GNN:胶囊图神经网络的PyTorch实现 | ICLR 2019_第4张图片注意力模块架构。首先压平主胶囊,利用两层全连接神经网络产生每个胶囊的注意力值。利用基于节点的归一化(对每行进行归一化)来生成最终注意力值。将标准化值与主胶囊相乘来计算标度胶囊。

实验设置与结果

我们验证了从CapsGNN中提取的图嵌入与大量SOTA方法的性能,与一些经典方法的最优性能做了对比。此外还进行了实验研究,评估胶囊对图编码特征效率的影响。我们对生成的图/类胶囊进行了简要分析。实验结果和分析如下所示。

表1为生物数据集的实验结果,表2为社会数据集的实验结果。对于每个数据集,以粗体突出显示前2个准确度。

与所有其他算法相比,CapsGNN在10个数据集中的6个的表现排名位居前两名,并且在其他数据集上也实现了基本相当的结果。与所有其他端到端架构相比,CapsGNN在所有社交数据集中均名列首位。

基于GNN,强于GNN:胶囊图神经网络的PyTorch实现 | ICLR 2019_第5张图片

表1:生物数据集的实验结果

基于GNN,强于GNN:胶囊图神经网络的PyTorch实现 | ICLR 2019_第6张图片

表2:社交数据集的实验结果

胶囊的效率

在胶囊的效率测试实验中,GNN的层数设置为L = 3,每层的通道数都设置为Cl = 2。通过调整节点的维度(dn)、图(dg)、胶囊和图形、胶囊的数量(P)来构造不同的CapsGNN。

基于GNN,强于GNN:胶囊图神经网络的PyTorch实现 | ICLR 2019_第7张图片

表3:胶囊效率评估实验中经过测试的体系结构详细信息

基于GNN,强于GNN:胶囊图神经网络的PyTorch实现 | ICLR 2019_第8张图片

图3:特征表示效率的比较。横轴表示测试架构的设置,纵轴表示NCI1的分类精度。

基于GNN,强于GNN:胶囊图神经网络的PyTorch实现 | ICLR 2019_第9张图片

图胶囊的可视化

基于GNN,强于GNN:胶囊图神经网络的PyTorch实现 | ICLR 2019_第10张图片

分类胶囊的可视化

胶囊图网络:基于GNN的高效快捷的新框架

CapsGNN是一个新框架,将胶囊理论融合到GNN中,来实现更高效的图表示学习。该框架受CapsNet的启发,在原体系结构中引入了胶囊的概念,在从GNN提取的节点特征的基础上,以向量的形式提取特征。

利用CapsGNN,一个图可以表示为多个嵌入,每个嵌入都可以捕获不同方面的图属性。生成的图形和类封装不仅可以保留与分类相关的信息,还可以保留关于图属性的其他信息,这些信息可能在后续流程中用到。CapsGNN是一种新颖、高效且强大的数据驱动方法,可以表示图形等高维数据。

与其他SOTA算法相比,CapsGNN模型在10个图表分类任务中有6个成功实现了更好或相当的性能,在社交数据集上的表现尤其显眼。与其他类似的基于标量的体系结构相比,CapsGNN在编码特征方面更有效,这对于处理大型数据集非常有用。

关于开源代码和模型的一些补充信息

运行环境

代码库在Python 3.5.2中实现。用于开发的软件包版本如下:

networkx          1.11
tqdm              4.28.1
numpy             1.15.4
pandas            0.23.4
texttable         1.5.0
scipy             1.1.0
argparse          1.1.0
torch             0.4.1
torch-scatter     1.1.2
torch-sparse      0.2.2
torch-cluster     1.2.4
torch-geometric   1.0.3
torchvision       0.2.1

数据集

代码会从input文件夹中获取训练图,图存储形式为JSON。用于测试的图也存储为JSON文件。每个节点id和节点标签必须从0开始索引。字典的键是存储的字符串,以使JSON能够序列化排布。

每个JSON文件都具有以下的键值结构:

{"edges": [[0, 1],[1, 2],[2, 3],[3, 4]],
 "labels": {"0": "A", "1": "B", "2": "C", "3": "A", "4": "B"},
 "target": 1}

边缘键(edges key)具有边缘列表值,用于描述连接结构。标签键具有每个节点的标签,这些标签存储为字典- 在此嵌套字典中,标签是值,节点标识符是键。目标键具有整数值,该值代表了类成员资格。

输出

预测结果保存在output目录中。每个嵌入都有一个标题和一个带有图标识符的列。最后,预测会按标识符列排序。

训练CapsGNN模型由src /main.py脚本处理,该脚本提供以下命令行参数。

输入和输出选项

--training-graphs   STR    Training graphs folder.      Default is `dataset/train/`.
  --testing-graphs    STR    Testing graphs folder.       Default is `dataset/test/`.
  --prediction-path   STR    Output predictions file.     Default is `output/watts_predictions.csv`.

模型选项

--epochs                      INT     Number of epochs.                  Default is 10.
  --batch-size                  INT     Number fo graphs per batch.        Default is 32.
  --gcn-filters                 INT     Number of filters in GCNs.         Default is 2.
  --gcn-layers                  INT     Number of GCNs chained together.   Default is 5.
  --inner-attention-dimension   INT     Number of neurons in attention.    Default is 20.  
  --capsule-dimensions          INT     Number of capsule neurons.         Default is 8.
  --number-of-capsules          INT     Number of capsules in layer.       Default is 8.
  --weight-decay                FLOAT   Weight decay of Adam.              Defatuls is 10^-6.
  --lambd                       FLOAT   Regularization parameter.          Default is 1.0.
  --learning-rate               FLOAT   Adam learning rate.                Default is 0.01.

论文地址:

https://openreview.net/pdf?id=Byl8BnRcYm

Github相关资源:

https://github.com/benedekrozemberczki/CapsGNN#outputs

好消息!

小白学视觉知识星球

开始面向外开放啦


基于GNN,强于GNN:胶囊图神经网络的PyTorch实现 | ICLR 2019_第11张图片

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。


下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。


下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。


交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

你可能感兴趣的:(神经网络,算法,大数据,编程语言,python)