快速入门 TFGNN
TFGNN是基于 TensorFlow 的 GNN 库,它同时实现了 MessagePassing 和 GraphNets 框架,这意味着您可以轻松地在框架中设置上下文(全局)值。
开始前的一些有用链接:图神经网络简介(GNN 初学者最好的介绍博客)、TFGNN 发表的论文(CoRR 2022)。
基本数据结构:GraphTensor
构建 GNN 训练管道的第一步是构建数据集。TFGNN的基本数据结构是tfgnn。GraphTensor和node-sets/edge-sets/context,每个集合都有自己的“特征”。
详细介绍文档:https /github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/graph_tensor.md
输入构建流水线
在建模之前,我们首先需要逐步构建我们的训练集。https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/input_pipeline.md
第一步:描述图模式:https /github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/schema.md
在这一步中,我们定义节点集/边集特征和上下文特征;以下是一些注意事项:
tfgnn 文档中提到“标量特征不需要维度识别”。但是实践表明,如果要在建模中直接处理那些标量特征,可以将它们的shape设置为’{dim {size: 1}}',并将数据类型定义为FLOAT。
TFGNN 模型可以在“有向”边的两个方向上传播消息,默认情况下,TFGNN 在消息传递阶段采用有向边。因此,无需在数据准备步骤中指定“无向”边。
有多种图任务类型:单图分类/回归(给一堆小图对它们进行分类),单节点分类/回归(从大图中采样子图)等。
Step2:数据样本准备&采样:https /github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/data_prep.md
in-python编码的步骤:
使用 GraphTensor.from_pieces() 函数为 GraphTensor 类型构造一个 eager 实例。
使用 tfgnn.write_example(graph) 和 writer.write(example.SerializeToString()) 序列化内存实例并将其写入磁盘上的文件。
Step3:读取文件到Dataset:https /github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/input_pipeline.md
要读取一个 GraphTensor(复合张量),我们需要根据定义的 Spec来解析serializable :tfgnn.parse_single_example/parse_example();
建模管道(使用 Keras API):
通常,我们构建2个Keras模型用于预处理和实际建模:1个预处理模型+1个主训练模型;请注意,这两个模型都是在标量样本模式中定义的……
流水线总结:读取数据集->解析为张量->构建预处理模型->构建训练模型->拟合训练模型->导出服务
预处理模型定义:https /github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/input_pipeline.md
预处理模型的动机是将原始特征从离散特征处理为输入向量,例如将 INT 特征映射到 8 长向量;或将一些功能连接在一起。预处理函数的返回值应该是节点/边的起始隐藏状态。
脚步:
对于预处理函数,最好使用 tf.keras.layers 的内置函数。
TFGNN 将一批图样本合并到一个具有单独组件(没有连接的子图)的大图中,然后应用更新函数,将批次作为标量样本。这样做的原因是为了使基于批处理的计算变得可行,因为每个样本可能具有不同的大小,这与传统的表格数据集有很大不同(在图张量上表示’ sizes '= [ total_size,],其中total_size = batch_size *(set_size) *频道大小)。
主要模型定义:
https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/gnn_modeling.md
在 TFGNN 中构建 GNN 层的方式有以下三种:
使用 GATv2、GCN 等原生支持层的 XXXGraphUpdate 函数。
2.从TensorFlow- Keras scratch 自定义tf.keras.layers.Layer 。
祝你好运!