【论文阅读-GSAPool】Structure-Feature based Graph Self-adaptive Pooling

论文地址:https://arxiv.org/pdf/2002.00848.pdf
代码地址:https://github.com/psp3dcg/GSAPool
来源:WWW2020

近年来,提出了各种处理Graph的方法。但是,这些方法大多侧重于节点特征的聚合,而很少关注图池化。而现有的基于top-k的图池化方法也存在一些问题。首先,为了构建池图拓扑,目前的top-k选择方法仅从单个角度评估节点的重要性,过于简单和客观。其次,在池化过程中,未选择节点的特征信息直接丢失,不可避免地导致大量图特征信息丢失。
这篇论文提出的一种新的图自适应池化方法主要解决以下问题:
(1)为了构建合理的池化图拓扑,同时考虑了图的结构和特征信息,在节点选择上提供了更高的准确性和客观性;
(2)为了使合并的节点包含足够有效的图信息,在丢弃不重要节点之前对节点特征信息进行聚合;因此,被选中的节点包含来自邻居节点的信息,可以增强未被选中节点特征的使用。

模型架构

【论文阅读-GSAPool】Structure-Feature based Graph Self-adaptive Pooling_第1张图片
延续了SAGPool的分层池化模型架构,便于进行实验对比。

代码逻辑

【论文阅读-GSAPool】Structure-Feature based Graph Self-adaptive Pooling_第2张图片

卷积层

ChebConv直接使用拉普拉斯矩阵作为卷积算子。在不进行特征向量分解的情况下,减少了参数的数量,加快了计算速度。
GCNConv将卷积扩展到图结构数据中,可以获得更好的图表示,在半监督任务中表现良好。GraphSAGE通过聚集邻域内的节点特征信息来生成节点嵌入。
GAT采用注意机制,在聚集过程中计算相邻节点的注意分数作为特征信息的权重值。

实验中使用GCNConv

图池化层

【论文阅读-GSAPool】Structure-Feature based Graph Self-adaptive Pooling_第3张图片

SBTL(The structure-based topology learning)

通常,图中包含大量的节点和边,这些节点和边表示丰富的结构信息。因此,根据节点的结构信息来评估节点的重要性是有效的。考虑到GCNConv考虑了结构信息,该方法适合于评估节点的重要性。GCNConv的表达式如下:
GCN
score函数也可以使用其他gnn,如ChebConv、GraphSAGE、GAT。本论文中使用的是GCNConv。
代码实现:

#SBTL
score_s = self.sbtl_layer(x,edge_index).squeeze()#sbtl_layer是一个(128,1)的GCNConv

FBTL(The feature-based topology learning)

在图数据中,节点通常包含特征信息。利用节点特征信息进行评价是很重要的,因为节点的特征可以在很大程度上代表节点。节点特性的影响不能被忽略。
采用了MLP作为节点特征提取方法,因为MLP具有较好的自适应特征信息聚合能力。表达式如下:
FBTL
代码实现:

 score_f = self.fbtl_layer(x).squeeze() #这里的fbtl_layer是一个(128,1)的Linear层

SFTL(The structurefeature topology learning )

组合两种节点评分方式,表示如下:
SFTL
代码实现:

score = score_s*self.alpha + score_f*(1-self.alpha)
score = score.unsqueeze(-1) if score.dim()==0 else score

Fusion

【论文阅读-GSAPool】Structure-Feature based Graph Self-adaptive Pooling_第4张图片
不同的节点特征融合策略:
(a.)合并后的节点只保留自己的特征,
(b.)合并后的节点在1跳邻居节点内的聚集特征,
(c.)合并后的节点在n跳邻居节点内的聚集特征
带箭头的边缘表示融合过程中特征信息流动的方向。

实验部分用了1跳邻居的GAT

#fusion
if(self.fusion_flag == 1):
	x = self.fusion(x, edge_index)	#fusion是GAT层

Readout层

同SAGPool,对一个子图取全局池化和平均池化作为该子图的embding

x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

更新子图,保持连通性

#更新节点
x = x[perm] * score[perm].view(-1, 1)
x = self.multiplier * x if self.multiplier != 1 else x
batch = batch[perm]
#更新邻接表
edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=score.size(0))

分类

x = F.relu(self.lin1(x))
x = F.dropout(x, p=self.dropout_ratio, training=self.training)
x = F.relu(self.lin2(x))
x = F.log_softmax(self.lin3(x), dim=-1)

总结

这篇论文也是采用基于TOPK的节点池化。主要在SAGPool的基础上增加了节点的MLP得分,考虑了节点的属性信息。同时,在TOPK丢弃节点前,先进行GAT操作,保留了丢弃节点的信息。

你可能感兴趣的:(论文阅读,GCN)