最近做毕业设计,需要用到图神经网络(以下简称GNN)。由于刚入门GNN,不想看大段的公式和相关论文(然而事实证明该看的永远逃不了),所以怎么办?百度上找呗!因为自己平时用pytorch比较多,所以找到了基于pytorch的图神经网络库,pytorch_geometric(以下简称pyg)。在用这个库的过程中,由于这个库“约定大于配置”的一些特性,也遇到了许多坑,而中文资料中,大多都是直接翻译文档,对一些细节没有做解释。因此将整个过程记录下来,供大家未来参考。
考虑到有许多同学和我最开始一样,就是想知道GNN是干什么的,大致原理是什么,因此这一章将简单介绍GNN
相信大家都知道“图”是什么东西,一堆的节点(每个节点都有各自的特征),中间用箭头连起来。那么我们就很自然地想,这些相连的节点,可能信息上比较相关,可以互补,那我们能不能对这些相关的节点做一定的特征聚合操作呢?比如求和(这就是GIN),取平均(这就是大家常说的GCN),加权求和(这就是GAN)?做特征聚合时,可以只对一阶相邻(也就是直接相邻)节点做聚合,也可以对二阶相邻(最多可以通过一个中间节点间接相连)节点做聚合,由此衍生出各种花里胡哨的GNN……
所以我们说,GNN的核心操作是节点的特征聚合,具体怎么聚合,各个GNN有自己的花样。但是这些花样一定都是依据邻接关系得到的。毕竟脱离了图结构,GNN也没有存在的意义了,对吧?
X l + 1 = a g g r e g a t e ( X l , A ) A 表 示 邻 接 矩 阵 , X l 表 示 第 l 层 的 输 出 X^{l+1}=aggregate(X^l,A)\\ A表示邻接矩阵,X^l表示第l层的输出 Xl+1=aggregate(Xl,A)A表示邻接矩阵,Xl表示第l层的输出
看到这个,我们可以联想一下CNN,一个3x3的卷积核,实际上就是对8邻域做了加权和,对吧?如果把二维图像看成是一张8邻域连通的图,那实际上就是GNN了。
除了聚合特征之外,GNN们通常还带有一个单节点的特征变换运算,这个运算可以是单纯的线性变换,可以带有非线性激活函数等,可以认为是对单节点的特征做了增强。所以GNN的大致运算过程可以写为
X l + 1 = a g g r e g a t e ( f ( X l ) , A ) X l + 1 = f ( a g g r e g a t e ( X l , A ) ) X^{l+1}=aggregate(f(X^l),A)\\ X^{l+1}=f(aggregate(X^l,A)) Xl+1=aggregate(f(Xl),A)Xl+1=f(aggregate(Xl,A))
这两种写法并不会影响输入输出的维度,重要的是明白有一次特征聚合,有一次特征增强即可。
pyg的功能比较强大,包括一些utils包下的图级别的工具函数,nn包下常见的GNN层,data包下封装好的图特征对象,loader包下封装好的图batch的loader,甚至还有端到端的GNN模型。自己搭GNN,主要会用到Data、DataLoader、nn、utils等相关工具。因此下面将从搭建一个最简单的GNN,并完成一次输入输出运算为主线,介绍这些工具的用法,以及一些暗箱约定(或者说,坑)
以下的内容主要根据pyg的官方教程梳理,对其中重要的坑会进行强调。
那么,想要使用三方库完成一次GNN运算,我们就必须了解以下的一些内容
接下来从探究以上问题出发,我们力求把pyg的用法讲明白
源教程
我们知道GNN的输入除了各顶点的特征之外,还有邻接矩阵,甚至还会有边的特征。pyg内置了Data对象,用于封装GNN的输入。Data对象中最常使用的几个属性包括以下三个。如果实际科研工作中需要使用更复杂的特征,可以回到上面找源教程。
data.x
: 节点特征矩阵,维度是 [num_nodes, num_node_features]
data.edge_index
: 图连接关系,也就是之前所说的邻接矩阵。只不过这里采用了稀疏格式的输入,维度是 [2, num_edges]
类型是 torch.long
。也就是只存储每条边的出发点和终止点,而不是真正的邻接矩阵(这样的矩阵在顶点多边少的时候,非常占内存)。data.y
: 模型的期望输出。如果是完成节点级别任务的GNN,维度一般为[num_nodes, *]
;如果是完成图级别任务的GNN,维度一般为 [1, *]
源教程
DataSet对象通常被我们用于原始数据读取和加工,将数据转换成DataLoader所能接受的输入。
说的现实点,主要就是把我们的数据源转换成一系列的Data对象。这个对象本身也是pyg封装好的,需要我们削足适履,把我们的转换逻辑填进去。以下直接摘自官方教程
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super().__init__(root, transform, pre_transform)
# 读取已经转换好格式的数据
self.data, self.slices = torch.load(self.processed_paths[0])
# 未处理好的数据,如果pyg发现文件不存在,会进行下载
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
# 处理好的数据,如果pyg发现文件不存在,会调用self.process()函数
@property
def processed_file_names(self):
return ['data.pt']
# 下载原始数据的函数,如果不需要就直接pass
def download(self):
download_url(url, self.raw_dir)
# 将原始数据进行处理,转换为Data对象,并且保存下来
def process(self):
# 假设这里已经经过了一系列处理,得到了包含Data对象的List
data_list = [...]
# 雷打不动的两句话,处理List并存盘
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
源教程
玩深度学习的同学们都知道,训练模型一般要把多个数据打包成一个mini-batch,再丢给模型训练(原因我就不解释了)。DataLoader就是完成这个工作的。如果你恰好用过pytorch,你肯定也知道pytorch默认的DataLoader会把batch_size
个样本打包成[batch_size, d1, d2, ..., dn]
维度的输入,其中[d1, d2, ..., dn]
是样本本来的特征维度。
而pyg的DataLoader,会把数据打包成[batch_size*num_nodes, num_node_features]
的维度(也就是batch_size
不会单独成一维)。说实话,这一点非常坑(当然,从性能的角度,也可以说“妙”)。pyg官方的解释是“为了增加并行度”,那么,这一步操作是怎么增加并行度的?
稍加思考就可以明白。之前我们就说过,GNN最重要的操作之一,就是进行特征聚合。那进行特征聚合的代码怎么写?我们以求均值为例,最暴力的,当然是对着邻接矩阵,一个一个把邻居的特征加起来再取平均了
X i l + 1 = ( ∑ j a i j x i j ) / n X i = X [ i , : ] x i j = X [ i , j ] X^{l+1}_i = (\sum_j a_{ij}x_{ij})/n\\ X_i = X[i,:]\\ x_{ij}=X[i,j] Xil+1=(j∑aijxij)/nXi=X[i,:]xij=X[i,j]
当然,大家都知道拿循环来算加权和,效率非常低,因此应该用矩阵的形式来表示这一运算。假设我们已经算出了各个节点加权和的系数,形成一个系数矩阵,那上面的循环直接就可以用一个矩阵乘法表示了
X l + 1 = A ⋅ X l A ∈ [ n u m _ n o d e s , n u m _ n o d e s ] X ∈ [ n u m _ n o d e s , n u m _ n o d e _ f e a t u r e s ] X^{l+1}=A\cdot X^l\\ A\in [num\_nodes,num\_nodes]\\ X\in [num\_nodes,num\_node\_features] Xl+1=A⋅XlA∈[num_nodes,num_nodes]X∈[num_nodes,num_node_features]
可以想象,A乘在左边,就是对X做了行变换,也就是对X的每一行进行了加权和。
图神经网络的计算效率是比较低的,多张图之间难以进行并行化。假如num_nodes
不是很大,那进行一次上述的运算,也不会有太大的加速。因此pyg的DadaLoader将数据打包成了[batch_size*num_nodes, num_node_features]
维度,相当于大大提升了参与一次图运算的顶点数,因此可以充分利用向量运算的优势。
mini-batch内打包了节点特征、样本标签、连接关系、batch信息等内容
我们现在以最简单的两个图为例,说明打包后的数据长什么样
接下来,特征、标签都正常拼接;但是节点连接关系会进行一定的运算
data.x
: 将两个图的节点特征直接拼接成[5, num_node_features]
的矩阵data.y
: 将两个图的节点标签直接拼接成[5, *]
的矩阵data.edge_index
: 将两个图混合成一张大图,形成[2, 10]
的矩阵。得到的邻接矩阵大概是[[0,1],[1,0],[0,2],[2,0],[1,2],[2,1],[3,4],[4,3]]
。这里为了看着方便,我把稀疏邻接矩阵转置了一下,实际上它的维度是[2, num_edges]
欸,哪来的节点3和4呢?这就是这一并行化算法巧妙的地方,它将多个图融合成一张大图——其实图的编号没有太大的实际意义,它只是表达哪几个节点需要进行信息交换,需要把X的哪几行进行交换罢了。因此整个运算结果是非常正确的。
因此,我们再回头看一眼,Data对象中,各个元素的维度——知道为什么num_nodes
一定作为第一维了吗?想要最大限度地利用pyg库带来的遍历,就必须削足适履,迎合它的编码方式。在下一篇文章中,我将讲述如何定制DataLoader从而增加一些灵活性。
除了以上所述的那些内容,DataLoader还会打包一个batch信息。这一信息主要是为了从batch中再区分出各个图所用,在进行一些图级别的全局运算,比如softmax,比如global_average,比如global_max,我们肯定希望是在一个图样本中进行(不然在这张用于运算的大图上进行全局计算,有什么实际意义吗)
源教程
接下来我们搭一个最简单的两层GCN网络
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GCNConv(dataset.num_node_features, 16)
self.conv2 = GCNConv(16, dataset.num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
接下来是训练过程
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
最后是测试过程
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')
好吧,这一章很水,毕竟这只是一个最简单的GNN例子,pyg也内置了很多的GNN模型供大家调用。需要注意的是,在刚刚的例程中,其实pyg内置GNN层的输入,并不是data,而是分立的data.x和data.edge_index,很多其它GNN层也是如此。这是因为通常大家先会用基本的GNN层搭建一些小模块,这些模块内可能带着卷积,可能带着池化,而对于卷积层来说,它并不需要知道batch信息。
这一篇文章基本在翻译教程的过程中写完了,加上了自己在构建DataSet对象和DataLoader对象中踩的坑。但是这不是全部,后续,我将从一个交通领域的T-GCN模型出发,讲述如何使用pyg库复现这一时空图神经网络模型。