GNN的通俗理解

GNN流程:

1、聚合

2、更新

3、循环

假定现在的数据环境如下图所示,A、B、C、D、E代表节点,互相之间的边连线代表节点间的关系,节点旁边的括号数字代表特征。
GNN的通俗理解_第1张图片

1、聚合

拿节点A来举例:

1)单纯靠A自己一个节点的特征无法准确判断A节点的类别。

2)A跟B、C、D有联系,通过B、C、D可以判断A。(类似近朱者赤,比方说判断A同学学习成绩好不好,但看A来讲不知道,但B、C、D都好,那大概率A的成绩也好)

经过一次聚合后,聚合得到的信息:

邻居信息N=a·(2,2,2,2,2) + b· (3,3,3,3,3)+ c·(4,4,4,4,4)
//a、b、c为常数,可以通过模型训练来定,也可以自己手动定
(而且经常作为论文的改进点,比方说B对A来说很重要,常数b就设的大一点)

总结:把邻居的特征信息贴到自己身上来,作为自身特征信息的补足。

2、更新

GNN的通俗理解_第2张图片
一次聚合完之后:

A的信息=σ(W*((1,1,1,1,1)+ α*N))

解释:
1、A自己的特征(1,1,1,1,1)加上α倍的邻居特征信息。
2、乘权重W 。
3、最后乘一个激活函数σ

// α可以用attention或者自定义一个数,很多文章创新点就是α怎么选取。
// W是模型里需要训练的参数
// σ是激活函数(relu、sigmoid等)

3、循环

GNN的通俗理解_第3张图片

一次聚合后:

A包含B、C、D的信息
B包含A、C的信息
C包含A、B、D、E的信息
D包含A、C的信息
E包含C的信息

第二次聚合后:

(以A为例)
在A聚合C的时候,因为C中有上一次聚合E的信息,所以此时A获得了二阶邻居E的信息。

依次往后类推…
第三次聚合可以的到第三层邻居的信息
第n次聚合可以得到第n层邻居的信息。

应用

通过聚合更新,我们能够得到每个节点的表达,也就是特征feature,此时:

1、节点分类:
可以直接拿去分类,训练的时候,每个节点是一个样本,可以计算loss,优化权重W

2、关联预测:
可以类别分类,最简单的方式拿两个节点的特征拼一起,一样的计算loss做优化。

总结:输入A的特征和整个graph的结构,得到包含所有信息的A节点的最终信息,然后拿最终信息去分类、回归等等。说人话就是一个提取特征的方法。

你可能感兴趣的:(算法,机器学习,深度学习,人工智能)