参考资料:https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html。
图卷积算子(operator)可以被表示为消息传递机制。
x i ′ = U p d a t e ( x i , A g g r e g a t e j ∈ N ( i ) ( M e s s a g e ( x i , x j , e j , i ) ) ) \mathbf{x}_i^{\prime} = \mathbf {Update} \left( \mathbf{x}_i, \mathbf {Aggregate}_{j \in \mathcal{N}(i)} \left( \mathbf{Message} \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right)\right) xi′=Update(xi,Aggregatej∈N(i)(Message(xi,xj,ej,i)))
常规的阶段可以被细分为Linear、Message、Aggregate和Update这四个阶段,这里的Message+Aggregate可以看作是MPNN中的聚合阶段,Linear+Update可以看作是MPNN中的更新阶段:
在PyG中,GConv层会首先将特征矩阵X通过Linear层做一个特征维度的变换,然后显示调用self.propagate
方法进行消息传递操作(即便是自定义的图卷积层也必须遵循),而self.propagate
又会去调用self.message
、self.aggregate
和self.update
方法依次执行各个阶段的操作。
除此之外,一些空域卷积模型还有Sample阶段,如GraphSAGE。在Sample阶段,每个batch都会采样k-hop的邻居,并且每个GConv层都是使用k个采样集合迭代的进行卷积操作。
JK-Nets为了提高表达能力,提出了层聚合(layer-aggregate)的概念,它的主要思想是在原有卷积层后对所有的中间层进行再次聚合。
ChebNet、GCN、GAT、GraphSAGE、JK-Nets。
ChebNet是频域卷积,其他模型是空域卷积。除了JK-Nets外,其余的模型均使用了2层的结构。GraphSAGE模型每层采样2-hop的邻居节点。
修改了MessagePassing类和不同模型的GConv层,搭建GNN-Net。各阶段的执行时间是每个epoch训练的平均时间(GraphSAGE每个epoch包含很多个batch)。
代码上传到了github:https://github.com/ytchx1999/PyG-GNN-Test
各阶段的执行时间/ms | Sample | Linear | Message | Aggregate | Update | layer-aggregate* |
---|---|---|---|---|---|---|
ChebNet(2层) | ❎ | 2.2874 | 0.0436 | 0.2163 | 0.0012 | ❎ |
GCN(2层) | ❎ | 2.2451 | 0.0403 | 0.1350 | 0.0012 | ❎ |
GAT(2层) | ❎ | 2.2947 | 2.8942 | 0.1364 | 0.0013 | ❎ |
GraphSAGE(2层、minibatch) | 414.3549 | 24.4504 | 0.0173 | 10.8453 | 1.0079 | ❎ |
JK-Nets(6层) | ❎ | 2.3020 | 0.1121 | 0.3224 | 0.0037 | 0.0842 |
各阶段的执行时间/ms | Sample | Linear | Message | Aggregate | Update | layer-aggregate* |
---|---|---|---|---|---|---|
ChebNet(2层) | ❎ | 2.1374 | 40.1639 | 224.2229 | 0.0041 | ❎ |
GCN(2层) | ❎ | 1.1456 | 0.3615 | 0.8195 | 0.0025 | ❎ |
GAT(2层) | ❎ | 2.3111 | 5.5557 | 4.9341 | 0.0029 | ❎ |
GraphSAGE(2层、minibatch) | 72.9241 | 2.1413 | 0.0490 | 25.7912 | 1.1896 | ❎ |
JK-Nets(6层) | ❎ | 1.2503 | 0.9868 | 3.4509 | 0.0069 | 1.4495 |
各模型不同阶段的运算(除了ChebNet),均写成node-wise的形式。
各阶段的运算 | Sample | Linear+Message | Aggregate | Update | layer-aggregate* |
---|---|---|---|---|---|
ChebNet(2层) | ❎ | Z ( k ) ⋅ Θ ( k ) \mathbf{Z}^{(k)} \cdot \mathbf{\Theta}^{(k)} Z(k)⋅Θ(k) | sum | 无 | ❎ |
GCN(2层) | ❎ | Θ 1 d ^ j d ^ i x j \mathbf{\Theta}\frac{1}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j Θd^jd^i1xj | sum | self-loop | ❎ |
GAT(2层) | ❎ | α i , j Θ x j \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j} αi,jΘxj | sum | self-loop | ❎ |
GraphSAGE(2层、minibatch) | 依次采样k-hop邻居 | W 2 ⋅ x j \mathbf{W_2} \cdot \mathbf{x}_j W2⋅xj | mean | W 1 x i + m N ( i ) \mathbf{W}_1 \mathbf{x}_i +m_{N(i)} W1xi+mN(i) | ❎ |
JK-Nets(6层) | ❎ | Θ 1 d ^ j d ^ i x j \mathbf{\Theta}\frac{1}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j Θd^jd^i1xj | sum | self-loop | max |
【以下均是基于PyTorch Geometric的源码、结合实验现象进行的讨论】
属性数据(特征信息)–>特征矩阵X:[num_nodes, num_node_features]
结构数据(邻居/边信息)–>COO格式的边表edge_index:[2, num_edges]
——为后面Aggregate阶段的scatter操作埋下了伏笔!
Linear:主要是对特征维度进行变换。
Message:对特征值进行归一化操作。
Aggregate:我认为这才是GNN最核心的部分(区别于CNN),核心算子可以归纳为scatter(归约)操作。专门由torch_scatter实现(递归地阅读源代码,发现原理还是torch中的scatter)。https://pytorch-scatter.readthedocs.io/en/latest/。
# sum
out[ index[i] ][j] += src[i][j]
# max
out[ index[i] ][j] = max(src[i][j])
Update:可以分为组合阶段和整形变换。
Sample:主要进行k-hop邻居的采样操作(GraphSAGE)。
layer-aggregate*:对每个节点所有的中间层节点表示进行聚合(JK-Nets)。
GraphSAGE在GPU中dataloader速度比CPU慢了很多,目前正在查找原因。
通过查看源代码和资料,初步认为出现这种情况的原因是dataloader使用CPU进行采样处理,然后得到的采样结果再送往GPU进行训练,这种数据传输导致了每个batch都需要花费一定的时间。
知乎:pytorch dataloader数据加载占用了大部分时间,各位大佬都是怎么解决的?
采样之后的情况
for batch_size, n_id, adjs in train_loader:
print('batch_size:',batch_size)
print('n_id shape:',n_id.shape)
print('adjs length:',len(adjs))
for i,(edge_index, e_id, size) in enumerate(adjs):
print('edge_index_shape:{} e_id_shape:{} size:{}'.format(edge_index.shape,e_id.shape,size))
print('---------------')
batch是需要训练的节点数量。训练节点一共140个,每个batch取16个节点,因此共分为9组,前8组batch大小为16,最后一组为12。
n_id是2-hop总共采样的节点数量。
adjs是元祖,包括(edge_index, e_id, size),长度为跳数2。
batch_size: 16
n_id shape: torch.Size([237])
adjs length: 2
edge_index_shape:torch.Size([2, 336]) e_id_shape:torch.Size([336]) size:(237, 68)
edge_index_shape:torch.Size([2, 53]) e_id_shape:torch.Size([53]) size:(68, 16)
---------------
batch_size: 16
n_id shape: torch.Size([240])
adjs length: 2
edge_index_shape:torch.Size([2, 333]) e_id_shape:torch.Size([333]) size:(240, 68)
edge_index_shape:torch.Size([2, 56]) e_id_shape:torch.Size([56]) size:(68, 16)
---------------
batch_size: 16
n_id shape: torch.Size([271])
adjs length: 2
edge_index_shape:torch.Size([2, 429]) e_id_shape:torch.Size([429]) size:(271, 88)
edge_index_shape:torch.Size([2, 73]) e_id_shape:torch.Size([73]) size:(88, 16)
---------------
batch_size: 16
n_id shape: torch.Size([275])
adjs length: 2
edge_index_shape:torch.Size([2, 417]) e_id_shape:torch.Size([417]) size:(275, 76)
edge_index_shape:torch.Size([2, 67]) e_id_shape:torch.Size([67]) size:(76, 16)
---------------
batch_size: 16
n_id shape: torch.Size([202])
adjs length: 2
edge_index_shape:torch.Size([2, 351]) e_id_shape:torch.Size([351]) size:(202, 75)
edge_index_shape:torch.Size([2, 64]) e_id_shape:torch.Size([64]) size:(75, 16)
---------------
batch_size: 16
n_id shape: torch.Size([221])
adjs length: 2
edge_index_shape:torch.Size([2, 317]) e_id_shape:torch.Size([317]) size:(221, 64)
edge_index_shape:torch.Size([2, 51]) e_id_shape:torch.Size([51]) size:(64, 16)
---------------
batch_size: 16
n_id shape: torch.Size([273])
adjs length: 2
edge_index_shape:torch.Size([2, 475]) e_id_shape:torch.Size([475]) size:(273, 93)
edge_index_shape:torch.Size([2, 78]) e_id_shape:torch.Size([78]) size:(93, 16)
---------------
batch_size: 16
n_id shape: torch.Size([255])
adjs length: 2
edge_index_shape:torch.Size([2, 383]) e_id_shape:torch.Size([383]) size:(255, 77)
edge_index_shape:torch.Size([2, 63]) e_id_shape:torch.Size([63]) size:(77, 16)
---------------
batch_size: 12
n_id shape: torch.Size([240])
adjs length: 2
edge_index_shape:torch.Size([2, 333]) e_id_shape:torch.Size([333]) size:(240, 72)
edge_index_shape:torch.Size([2, 60]) e_id_shape:torch.Size([60]) size:(72, 12)
---------------
https://pytorch-geometric.readthedocs.io/en/latest/notes/sparse_tensor.html
在图比较大或比较稠密的时候,使用边表edge_index这种图结构数据,需要在Aggregate阶段显式的矩阵化src(x_j),这会导致很高的内存占用,反而不如使用稀疏矩阵(SparseTensor)进行存储和运算了。
### 各模型的网络结构
# ChebNet
ChebNet(
(conv1): ChebConv(1433, 16, K=2, normalization=sym)
(conv2): ChebConv(16, 7, K=2, normalization=sym)
)
# GCN
GCNNet(
(conv1): GCNConv(1433, 16)
(conv2): GCNConv(16, 7)
)
# GAT
GATNet(
(conv1): GATConv(1433, 8, heads=8)
(conv2): GATConv(64, 7, heads=1)
)
# GraphSAGE
SAGENet(
(convs): ModuleList(
(0): SAGEConv(1433, 16)
(1): SAGEConv(16, 7)
)
)
# JK-Nets
JKNet(
(conv0): GCNConv(1433, 16)
(conv1): GCNConv(16, 16)
(conv2): GCNConv(16, 16)
(conv3): GCNConv(16, 16)
(conv4): GCNConv(16, 16)
(conv5): GCNConv(16, 16)
(jk): JumpingKnowledge(max)
(fc): Linear(in_features=16, out_features=7, bias=True)
)