图算法源码学习stellargraph-1.graphsage

为了进一步熟悉基本的图方法,简单阅读github上的一个优秀源码 https://github.com/stellargraph/stellargraph ,做一些简单记录。

简介

此项目基于tf2.0实现了常用的图算法如 GraphSage、GCN、Node2Vec等

项目结构

├── AUTHORS
├── CHANGELOG.md
├── CONTRIBUTING.md
├── CONTRIBUTORS
├── LICENSE
├── MANIFEST.in
├── README.md
├── RELEASE_PROCEDURE.md
├── codecov.yml
├── demos        # 例子,建议从这学起
│   ├── README.md
│   ├── basics
│   ├── calibration
│   ├── community_detection
│   ├── connector
│   ├── embeddings
│   ├── ensembles
│   ├── graph-classification
│   ├── interpretability
│   ├── link-prediction
│   ├── node-classification
│   └── use-cases
├── docker # docker环境
│   ├── stellargraph
│   ├── stellargraph-ci-runner
│   ├── stellargraph-neo4j
│   └── stellargraph-treon
├── docker-compose.yml
├── docs # 文档
│   ├── Makefile
│   ├── README.md -> ../README.md
│   ├── api.txt
│   ├── conf.py
│   ├── hinsage.txt
│   ├── images
│   ├── index.txt
│   └── requirements.txt
├── meta.yaml
├── pytest.ini
├── requirements.txt
├── scripts
│   ├── README.md
│   ├── format_notebooks.py
│   ├── test_demos.py
│   └── whitespace.sh
├── setup.py
├── stellar-graph-banner.png
├── stellargraph   #核心代码
│   ├── __init__.py 
│   ├── calibration.py
│   ├── connector
│   ├── core
│   ├── data
│   ├── datasets      #读取并构建图 这个需要看
│   ├── ensemble.py
│   ├── globalvar.py
│   ├── interpretability
│   ├── layer   #具体策略实现部分
│   ├── losses.py
│   ├── mapper
│   ├── random.py
│   ├── utils
│   └── version.py
└── tests
    ├── __init__.py
    ├── core
    ├── data
    ├── datasets
    ├── interpretability
    ├── layer
    ├── mapper
    ├── reproducibility
    ├── resources
    ├── test_calibration.py
    ├── test_ensemble.py
    ├── test_losses.py
    ├── test_random.py
    └── test_utils

39 directories, 39 files

1. 数据读取

以directed-graphsage-on-cora-example为例,

dataset = datasets.Cora()
display(HTML(dataset.description))
G, node_subjects = dataset.load(directed=True)

第三行为读取dataframe格式的数据,构建图的入口代码为 datasets/datasets.py 调用顺序为

load 加载->_load_cora_or_citeseer 加载cora数据> cls 构建图(dataset.py 77行,加载graph)

读取的数据格式为:
node_data: 节点特征


image.png

edge边信息


image.png

输入的格式为:

 graph = cls({"paper": features}, {"cites": edgelist})# node的特征和边信息详见(datasets.py 84行)

2. 图构建

以构建有向图为例,77行的cls为StellarDiGraph类

StellarDiGraph(graph.py)-> 返回networkx结构的数据和node节点的subject数据

3. 节点采样构建feature

batch_size = 50 #每次50个node训练
in_samples = [5, 2]  # 入度 第一层采样5个 第二层采样2个node
out_samples = [5, 2] # 出度
generator = DirectedGraphSAGENodeGenerator(G, batch_size, in_samples, out_samples) # 实例化生成batch和shuffle的方法

参数说明

G (StellarDiGraph): The machine-learning ready graph.
batch_size (int): Size of batch to return.
in_samples (list): The number of in-node samples per layer (hop) to take.
out_samples (list): The number of out-node samples per layer (hop) to take.
seed (int): [Optional] Random seed for the node sampler.

路径: stellargraph/mapper/sampled_node_generators.py

DirectedGraphSAGENodeGenerator-> DirectedBreadthFirstNeighbours(explorer.py) -DirectedBreadthFirstNeighbours(采样舒适化)-> sample_features 采样->返回feature

DirectedBreadthFirstNeighbours

根据输入的in out samples 进行采样, 并将节点打平, 特征拼接成feature
(这个是在NodeSequencebatch调用的)

4. 训练

/Users/clz/PycharmProjects/stellargraph/stellargraph/mapper/sampled_node_generators.py
103行 flow函数

#  batch和shuffle构建
train_gen = generator.flow(train_subjects.index, train_targets, shuffle=True)
# sage构建
graphsage_model = DirectedGraphSAGE(
    layer_sizes=[32, 32], generator=generator, bias=False, dropout=0.5,
)
# 网络结构构建
x_inp, x_out = graphsage_model.in_out_tensors()
prediction = layers.Dense(units=train_targets.shape[1], activation="softmax")(x_out)
model = Model(inputs=x_inp, outputs=prediction)
model.compile(
    optimizer=optimizers.Adam(lr=0.005),
    loss=losses.categorical_crossentropy,
    metrics=["acc"],
)

流程

flow -> NodeSequence(提供batch和shuffle方法) ->DirectedGraphSAGE->in_out_tensors()

NodeSequence

提供batch和shuffle方法 返回batch_feats, batch_targets
batch_feats: 被采样node的特征
batch_targets:中心节点的target(label)

graphsage_model.in_out_tensors()

in_tensor是每个采样节点的特征使用keras input进行了转换作为输入
out_tensor是in_tensor调用了apply_layer方法增加了MeanAggregator和drop层

MeanAggregator

graphsage.py 307行

你可能感兴趣的:(图算法源码学习stellargraph-1.graphsage)