GNN典型模型的各阶段执行时间与算子分析

GNN模型的阶段划分

参考资料:https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html。
图卷积算子(operator)可以被表示为消息传递机制。
x i ′ = U p d a t e ( x i , A g g r e g a t e j ∈ N ( i ) ( M e s s a g e ( x i , x j , e j , i ) ) ) \mathbf{x}_i^{\prime} = \mathbf {Update} \left( \mathbf{x}_i, \mathbf {Aggregate}_{j \in \mathcal{N}(i)} \left( \mathbf{Message} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right)\right) xi=Update(xi,AggregatejN(i)(Message(xi,xj,ej,i)))
常规的阶段可以被细分为Linear、Message、Aggregate和Update这四个阶段,这里的Message+Aggregate可以看作是MPNN中的聚合阶段,Linear+Update可以看作是MPNN中的更新阶段:

GNN框架
聚合阶段
更新阶段
Message
Aggregate
Linear
Update
  1. Linear:用来对邻居节点的特征维度进行变换。 一般就是一个Linear层(矩阵乘法)。
  2. Message:用来对邻居节点的特征数值进行归一化。 一般是x_j和norm进行乘法操作,norm可以是注意力权重、归一化权重等等。
  3. Aggregate:聚合上一步处理好的每个节点的邻居节点特征(生成一条消息)。 一般是对于邻居特征矩阵的scatter归约操作。[N,out_dim]
  4. Update:通过上一步生成的消息更新当前层节点的嵌入表示。 可以是MLP,很多模型因为有self-loop所以没有这个阶段。使用 x i x_i xi m N ( i ) m_{N(i)} mN(i)进行更新。

在PyG中,GConv层会首先将特征矩阵X通过Linear层做一个特征维度的变换,然后显示调用self.propagate方法进行消息传递操作(即便是自定义的图卷积层也必须遵循),而self.propagate又会去调用self.messageself.aggregateself.update方法依次执行各个阶段的操作。

除此之外,一些空域卷积模型还有Sample阶段,如GraphSAGE。在Sample阶段,每个batch都会采样k-hop的邻居,并且每个GConv层都是使用k个采样集合迭代的进行卷积操作。

JK-Nets为了提高表达能力,提出了层聚合(layer-aggregate)的概念,它的主要思想是在原有卷积层后对所有的中间层进行再次聚合。

实验中使用的GNN模型

ChebNet、GCN、GAT、GraphSAGE、JK-Nets。
ChebNet是频域卷积,其他模型是空域卷积。除了JK-Nets外,其余的模型均使用了2层的结构。GraphSAGE模型每层采样2-hop的邻居节点。

GNN模型的各阶段执行时间

修改了MessagePassing类和不同模型的GConv层,搭建GNN-Net。各阶段的执行时间是每个epoch训练的平均时间(GraphSAGE每个epoch包含很多个batch)。

代码上传到了github:https://github.com/ytchx1999/PyG-GNN-Test

  • 实验环境:云服务器 + 一块Tesla T4 + PyTorch Geometric。
    • CPU
      型号:Intel® Xeon® Gold 5218 CPU
      内存:128G
      内核:64核
    • GPU
      Tesla T4 *1
      显存:16G
  • 数据集:Cora。
  • 计时单位:ms。
各阶段的执行时间/ms Sample Linear Message Aggregate Update layer-aggregate*
ChebNet(2层) 2.2874 0.0436 0.2163 0.0012
GCN(2层) 2.2451 0.0403 0.1350 0.0012
GAT(2层) 2.2947 2.8942 0.1364 0.0013
GraphSAGE(2层、minibatch) 414.3549 24.4504 0.0173 10.8453 1.0079
JK-Nets(6层) 2.3020 0.1121 0.3224 0.0037 0.0842
  • 本地CPU的实验:
    • Intel® Core™ i5-1038NG7 CPU @ 2.00GHz
    • 内存:16G
    • 内核:4核
各阶段的执行时间/ms Sample Linear Message Aggregate Update layer-aggregate*
ChebNet(2层) 2.1374 40.1639 224.2229 0.0041
GCN(2层) 1.1456 0.3615 0.8195 0.0025
GAT(2层) 2.3111 5.5557 4.9341 0.0029
GraphSAGE(2层、minibatch) 72.9241 2.1413 0.0490 25.7912 1.1896
JK-Nets(6层) 1.2503 0.9868 3.4509 0.0069 1.4495

GNN模型各阶段的算子种类及特征

各模型不同阶段的运算(除了ChebNet),均写成node-wise的形式。

各阶段的运算 Sample Linear+Message Aggregate Update layer-aggregate*
ChebNet(2层) Z ( k ) ⋅ Θ ( k ) \mathbf{Z}^{(k)} \cdot \mathbf{\Theta}^{(k)} Z(k)Θ(k) sum
GCN(2层) Θ 1 d ^ j d ^ i x j \mathbf{\Theta}\frac{1}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j Θd^jd^i 1xj sum self-loop
GAT(2层) α i , j Θ x j \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j} αi,jΘxj sum self-loop
GraphSAGE(2层、minibatch) 依次采样k-hop邻居 W 2 ⋅ x j \mathbf{W_2} \cdot \mathbf{x}_j W2xj mean W 1 x i + m N ( i ) \mathbf{W}_1 \mathbf{x}_i +m_{N(i)} W1xi+mN(i)
JK-Nets(6层) Θ 1 d ^ j d ^ i x j \mathbf{\Theta}\frac{1}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j Θd^jd^i 1xj sum self-loop max

【以下均是基于PyTorch Geometric的源码、结合实验现象进行的讨论】

实验所使用的图数据结构:

属性数据(特征信息)–>特征矩阵X:[num_nodes, num_node_features]
结构数据(邻居/边信息)–>COO格式的边表edge_index:[2, num_edges]——为后面Aggregate阶段的scatter操作埋下了伏笔!

各阶段算子种类及特征分析

Linear:主要是对特征维度进行变换。

  • 这里,为了增加计算的并行性、减少重复计算,同时也为了能够在此阶段不进行遍历操作,整形阶段一般是经过一个Linear层对特征矩阵X的维度进行变换。换句话说就是个矩阵乘法——特征矩阵X * 权重矩阵 Θ \mathbf{\Theta} Θ
  • Linear阶段的算子就是矩阵乘法,是计算密集型的操作。

Message:对特征值进行归一化操作。

  • 同整形阶段的原因,此阶段是对边表edge_index中的所有target_node组成的特征矩阵进行操作,同样是矩阵乘法——target_node特征矩阵X_j * 归一化矩阵Norm。
  • 经过上面的分析,Message阶段因为不进行遍历操作,大部分时间都在进行矩阵的乘法运算。因此,我认为Message阶段的算子就是矩阵乘法,这也就意味着此阶段是计算密集型的操作。

Aggregate:我认为这才是GNN最核心的部分(区别于CNN),核心算子可以归纳为scatter(归约)操作。专门由torch_scatter实现(递归地阅读源代码,发现原理还是torch中的scatter)。https://pytorch-scatter.readthedocs.io/en/latest/。
GNN典型模型的各阶段执行时间与算子分析_第1张图片
GNN典型模型的各阶段执行时间与算子分析_第2张图片

  • 其实在没仔细研究之前,我一直认为Aggregate阶段需要显式的遍历邻居节点,此外,对于为什么要使用边表而不是其他的图数据结构,我心里一直有一些疑问。直到发现这里的scatter操作,顿时感觉豁然开朗了,之前的疑问也都能解释了。
  • scatter的精髓其实是巧妙地利用了边表的索引特性。index其实是edge_index[1],表示target节点的id号。那么矩阵src又是如何得到的呢?——其实是利用src节点的id号edge_index[0]、节点的特征矩阵X构造出来的。这样,使用矩阵,也可以“无形”之中(隐式)建立一种规整的(regular)邻居对应关系。如图所示,将index中相同id号所对应矩阵src的行进行聚合(add/max/mean),就可以得到此id号的节点的邻域聚合结果(有点绕)。
  • 我将整个scatter过程概括为以下的伪代码表示:
# sum
out[ index[i] ][j] += src[i][j]
# max
out[ index[i] ][j] = max(src[i][j])
  • 稍微观察一下就能发现,由于每个节点的邻居聚合互不干扰,scatter操作具有很高的计算并行性。为了加速,部分代码使用C++和CUDA编写。
    GNN典型模型的各阶段执行时间与算子分析_第3张图片
  • 虽然设计的十分巧妙,scatter操作也不可避免的需要频繁地访问节点的特征信息(而非对矩阵进行整体操作),造成细粒度访存。考虑到涉及到运算的地方都是比较简单的element-wise比较操作,所以Aggregate阶段的scatter是访存密集型的操作
  • 这和图计算编程模型(VCPM)中的Scatter阶段遥相呼应。
    GNN典型模型的各阶段执行时间与算子分析_第4张图片

Update:可以分为组合阶段和整形变换。

  • 「注意:具有self-loop的模型没有Update阶段。很多模型为了简化也会省略Update阶段。」
  • 组合阶段:将当前节点的嵌入表示和Aggregate阶段生成的消息进行组合,得到新的节点表示。具有代表性的一类方法是使用skip-connection(拼接,求和,加权求和,max等)进行组合。这个阶段主要是进行element-wise的比较操作。
  • 整形变换:将组合阶段产生的新的节点表示经过Linear层进行整形(维数和任务相关),得到更新后的节点嵌入表示,送入下一层进行训练。这个阶段主要是矩阵乘法运算。
  • 综上,Update阶段(如果有的话)的主要算子还是矩阵乘法,属于计算密集型操作

Sample:主要进行k-hop邻居的采样操作(GraphSAGE)。

  • Sample一般在Message之前进行,为了保证适配性和迁移性,Sample阶段除了要计算k个邻居集合(k-hop),还要计算子图的edge_index等图数据结构并返回,以达到Message阶段的无缝衔接。
  • 为每个batch中的节点采样k-hop邻居需要频繁地访存,这一点显而易见;构造子图的数据结构又需要花费额外的空间和访存时间。相比于前面的两点,构造集合的并集操作花的时间就不值一提了。因此,Sample阶段是访存密集型的操作
    GNN典型模型的各阶段执行时间与算子分析_第5张图片

layer-aggregate*:对每个节点所有的中间层节点表示进行聚合(JK-Nets)。

  • 简单的矩阵拼接或取max的操作,没什么新鲜玩意。后面一般会加一个Linear层整形,此阶段的分析和Update差不太多。

实验结果的分析(根据上面的讨论)

  1. ChebNet、GCN、GAT、JK-Nets都没有Sample阶段,它们在Message阶段花费的时间最长。这是因为Message阶段有矩阵乘法、归一化等操作,GAT甚至还有额外的NN用来得到注意力权重 α i , j \alpha_{i,j} αi,j。除此之外,Aggregate阶段执行scatter操作,虽然是访存密集型,但因为其很好的并行性+CUDA加速+数据集不大,并没有花费很多的时间。这几个模型由于使用了self-loop,所以Update阶段约等于没有。
  2. GraphSAGE模型有Sample阶段。由于Sample阶段不仅需要访问1阶邻居,而是需要一直采样到k阶邻居,再加上需要重新构造子图,花费的时间最长。此外,在一个SAGEConv层中需要迭代地聚合k阶邻居(其他4个模型每层只聚合1阶邻居),所以Message、Aggregate和Update阶段所花费的时间也比其他几个模型要长。

GraphSAGE出现的问题

GraphSAGE在GPU中dataloader速度比CPU慢了很多,目前正在查找原因。
通过查看源代码和资料,初步认为出现这种情况的原因是dataloader使用CPU进行采样处理,然后得到的采样结果再送往GPU进行训练,这种数据传输导致了每个batch都需要花费一定的时间。
知乎:pytorch dataloader数据加载占用了大部分时间,各位大佬都是怎么解决的?

GraphSAGE采样结果记录

采样之前的情况:
在这里插入图片描述

采样之后的情况

for batch_size, n_id, adjs in train_loader:
    print('batch_size:',batch_size)
    print('n_id shape:',n_id.shape)
    print('adjs length:',len(adjs))
    
    for i,(edge_index, e_id, size) in enumerate(adjs):
        print('edge_index_shape:{} e_id_shape:{} size:{}'.format(edge_index.shape,e_id.shape,size))
    
    print('---------------')

batch是需要训练的节点数量。训练节点一共140个,每个batch取16个节点,因此共分为9组,前8组batch大小为16,最后一组为12。

n_id是2-hop总共采样的节点数量。

adjs是元祖,包括(edge_index, e_id, size),长度为跳数2。

  • edge_index:采样后构造的子图的边表
  • e_id:子图边表中的边在全图中的边id号
  • size:采样之后和之前集合中的节点数
batch_size: 16
n_id shape: torch.Size([237])
adjs length: 2
edge_index_shape:torch.Size([2, 336]) e_id_shape:torch.Size([336]) size:(237, 68)
edge_index_shape:torch.Size([2, 53]) e_id_shape:torch.Size([53]) size:(68, 16)
---------------
batch_size: 16
n_id shape: torch.Size([240])
adjs length: 2
edge_index_shape:torch.Size([2, 333]) e_id_shape:torch.Size([333]) size:(240, 68)
edge_index_shape:torch.Size([2, 56]) e_id_shape:torch.Size([56]) size:(68, 16)
---------------
batch_size: 16
n_id shape: torch.Size([271])
adjs length: 2
edge_index_shape:torch.Size([2, 429]) e_id_shape:torch.Size([429]) size:(271, 88)
edge_index_shape:torch.Size([2, 73]) e_id_shape:torch.Size([73]) size:(88, 16)
---------------
batch_size: 16
n_id shape: torch.Size([275])
adjs length: 2
edge_index_shape:torch.Size([2, 417]) e_id_shape:torch.Size([417]) size:(275, 76)
edge_index_shape:torch.Size([2, 67]) e_id_shape:torch.Size([67]) size:(76, 16)
---------------
batch_size: 16
n_id shape: torch.Size([202])
adjs length: 2
edge_index_shape:torch.Size([2, 351]) e_id_shape:torch.Size([351]) size:(202, 75)
edge_index_shape:torch.Size([2, 64]) e_id_shape:torch.Size([64]) size:(75, 16)
---------------
batch_size: 16
n_id shape: torch.Size([221])
adjs length: 2
edge_index_shape:torch.Size([2, 317]) e_id_shape:torch.Size([317]) size:(221, 64)
edge_index_shape:torch.Size([2, 51]) e_id_shape:torch.Size([51]) size:(64, 16)
---------------
batch_size: 16
n_id shape: torch.Size([273])
adjs length: 2
edge_index_shape:torch.Size([2, 475]) e_id_shape:torch.Size([475]) size:(273, 93)
edge_index_shape:torch.Size([2, 78]) e_id_shape:torch.Size([78]) size:(93, 16)
---------------
batch_size: 16
n_id shape: torch.Size([255])
adjs length: 2
edge_index_shape:torch.Size([2, 383]) e_id_shape:torch.Size([383]) size:(255, 77)
edge_index_shape:torch.Size([2, 63]) e_id_shape:torch.Size([63]) size:(77, 16)
---------------
batch_size: 12
n_id shape: torch.Size([240])
adjs length: 2
edge_index_shape:torch.Size([2, 333]) e_id_shape:torch.Size([333]) size:(240, 72)
edge_index_shape:torch.Size([2, 60]) e_id_shape:torch.Size([60]) size:(72, 12)
---------------

思考

https://pytorch-geometric.readthedocs.io/en/latest/notes/sparse_tensor.html
在图比较大或比较稠密的时候,使用边表edge_index这种图结构数据,需要在Aggregate阶段显式的矩阵化src(x_j),这会导致很高的内存占用,反而不如使用稀疏矩阵(SparseTensor)进行存储和运算了。

实验附录

### 各模型的网络结构
# ChebNet
ChebNet(
  (conv1): ChebConv(1433, 16, K=2, normalization=sym)
  (conv2): ChebConv(16, 7, K=2, normalization=sym)
)
# GCN
GCNNet(
  (conv1): GCNConv(1433, 16)
  (conv2): GCNConv(16, 7)
)
# GAT
GATNet(
  (conv1): GATConv(1433, 8, heads=8)
  (conv2): GATConv(64, 7, heads=1)
)
# GraphSAGE
SAGENet(
  (convs): ModuleList(
    (0): SAGEConv(1433, 16)
    (1): SAGEConv(16, 7)
  )
)
# JK-Nets
JKNet(
  (conv0): GCNConv(1433, 16)
  (conv1): GCNConv(16, 16)
  (conv2): GCNConv(16, 16)
  (conv3): GCNConv(16, 16)
  (conv4): GCNConv(16, 16)
  (conv5): GCNConv(16, 16)
  (jk): JumpingKnowledge(max)
  (fc): Linear(in_features=16, out_features=7, bias=True)
)

你可能感兴趣的:(GNN实验,深度学习,python,图神经网络,GNN,PyG)