图(graph)神经网络--GAT网络(pytorch版代码分析)

GAT(Graph Attention Network)

GitHub项目(GAT[keras版]  GAT[pytotch版]  GAT[tensorflow版])

该项目做的任务仍是图中节点分类问题,语料仍是Cora


1.下载代码,并上传到服务器解压

unzip pyGAT-master.zip

2.选择或安装运行该程序需要的环境

pyGAT relies on Python 3.5 and PyTorch 0.4.1 (due to torch.sparse_coo_tensor).

激活环境  source activate pt_env

3.进入pyGAT-master目录,运行:Python main.py

以上操作,运行成功!!!




开始代码解剖

1.超参设置

2.加载数据

idx_features_labels    [0]是节点id    [1-1433]是节点的one-hot特征向量  [1434]是节点的label标签。  这个数据是从文件data/cora/cora.content文件中读出来的。

将刚才加载的idx_features_labels数据,取出features部分,用稀疏矩阵的形式存储;取出labels部分,转换成one-hot多分类向量。

从data/cora/cora.cites里读入数据,构建整个大图的邻接矩阵。

cora.cites里的数据格式如图,点对形式

3.搭建GAT模型

GAT(Graph Attention Network)

GAT整个模型,初始有8个注意力层

GraphAttentionLayer层代码

模型训练,输入数据转换过程,数据形状

你可能感兴趣的:(图(graph)神经网络--GAT网络(pytorch版代码分析))