Task03 基于图神经网络的节点表征学习

参考链接:https://github.com/datawhalechina/team-learning-nlp/blob/master/GNN/Markdown%E7%89%88%E6%9C%AC/5-%E5%9F%BA%E4%BA%8E%E5%9B%BE%E7%A5%9E%E7%BB%8F%E7%BD%91%E7%BB%9C%E7%9A%84%E8%8A%82%E7%82%B9%E8%A1%A8%E5%BE%81%E5%AD%A6%E4%B9%A0.md

一、引言

在图节点预测或边预测任务中,需要先构造节点表征(representation),节点表征是图节点预测和边预测任务成功的关键。
我们要根据节点的属性(可以是类别型、也可以是数值型)、边的信息、边的属性(如果有的话)、已知的节点预测标签,对未知标签的节点做预测。

二、MLP在图节点分类中的应用

多层感知机(MLP,Multilayer Perceptron)也叫人工神经网络(ANN,Artificial Neural Network),除了输入输出层,它中间可以有多个隐层,最简单的MLP只含一个隐层,即三层的结构,如下图:


从上图可以看到,多层感知机层与层之间是全连接的。多层感知机最底层是输入层,中间是隐藏层,最后是输出层。
理论上,MLP能够仅根据文件的内容,即它的词包特征表示来推断文件的类别,而无需考虑文件之间的任何关系信息。在一个简单的MLP中,该网络只对输入节点的特征进行操作,它在所有节点之间共享权重。
部分代码实现:

import numpy as np
import pandas as pd
import torch

class MLP(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(MLP, self).__init__()
        torch.manual_seed(12345)
        self.lin1 = Linear(dataset.num_features,hidden_channels)
        self.lin2 = Linear(hidden_channels,dataset.num_classes)
        
    def forward(self, x):
        x = self.lin1(x)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x
    
model = MLP(hidden_channels=16)
print(model)

三、GCN及其在图节点分类任务中的应用

GCN(Graph Convolutional Network) 图卷积网络处理的数据是图结构,属于拓扑结构。
图的特征提取的方法有空域和频域,GCN使用频域进行特征提取。

  • 空域:很直观,直接用相应顶点连接的neighbors来提取特征。
  • 频域:在图上进行信号处理的变换。

缺点:

  • 这个模型对于同阶的邻域上分配给不同的邻居的权重是完全相同的,这一点限制了模型对于空间信息的相关性的捕捉能力。
  • GCN结合临近节点特征的方式和图的结构依依相关,这局限了训练所得模型在其他图结构上的泛化能力。

部分代码实现:

from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_features,hidden_channels)
        self.conv2 = GCNConv(hidden_channels,dataset.num_classes)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x
    
model = GCN(hidden_channels=16)
print(model)

四、GAT及其在图节点分类任务中的应用

GraphAttentionNetwork(GAT)图注意网络提出了用注意力机制对邻近节点特征加权求和。 邻近节点特征的权重完全取决于节点特征,独立于图结构。GAT和GCN的核心区别在于如何收集并累和距离为1的邻居节点的特征表示。图注意力模型GAT用注意力机制替代了GCN中固定的标准化操作。本质上,GAT只是将原本GCN的标准化函数替换为使用注意力权重的邻居节点特征聚合函数。
优点:

  • 在GAT中,图中的每个节点可以根据邻节点的特征,为其分配不同的权值。
  • 引入注意力机制之后,只与相邻节点有关。
    部分代码实现:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GAT, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GATConv(dataset.num_features,hidden_channels)
        self.conv2 = GATConv(hidden_channels,dataset.num_classes)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

五、总结

在节点表征的学习中,MLP节点分类器只考虑了节点自身属性,忽略了节点之间的连接关系,它的结果是最差的;而GCN与GAT节点分类器,同时考虑了节点自身属性与周围邻居节点的属性,它们的结果优于MLP节点分类器。从中可以看出邻居节点的信息对于节点分类任务的重要性。
GCN图神经网络与GAT图神经网络的区别在于采取的归一化方法不同

  • 前者根据中心节点与邻居节点的度计算归一化系数,后者根据中心节点与邻居节点的相似度计算归一化系数。
  • 前者的归一化方式依赖于图的拓扑结构,不同节点其自身的度不同、其邻居的度也不同,在一些应用中可能会影响泛化能力。
  • 后者的归一化方式依赖于中心节点与邻居节点的相似度,相似度是训练得到的,因此不受图的拓扑结构的影响,在不同的任务中都会有较好的泛化表现。

你可能感兴趣的:(Task03 基于图神经网络的节点表征学习)