Goal: run the PyTorch implementation of the Graph Attention Network and get a rough understanding of the code.
Code: https://github.com/Diego999/pyGAT
Paper: "This is a pytorch implementation of the Graph Attention Network (GAT) model presented by Veličković et al. (2017, https://arxiv.org/abs/1710.10903)." ICLR 2018
Contents
1. Environment setup and running
1.1 Installing and configuring PyTorch
1.2 Installing requirements.txt
1.3 Run results
2. Code
2.1 Structure
2.2 Training code
Pre-training setup
Training
Testing
Training epoch by epoch
2.3 Data loading and model definition
To avoid confusion with GAN (Generative Adversarial Network), Graph Attention Network is abbreviated as GAT.
Reference: installing Anaconda on CentOS 6.3 and configuring PyTorch with CUDA.
Once the environment is configured, activate it: source activate torch
Required packages:
numpy==1.15.1
scipy==1.1.0
torch==0.4.1.post2
Running python train.py gives a warning:
train.py:96: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
'loss_train: {:.4f}'.format(loss_train.data[0]),
This warning comes from the torch version; it does not affect training or the results.
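The fix the warning itself suggests is mechanical (optional, since training works regardless): replace the old 0-dim indexing tensor.data[0] with tensor.item() wherever train.py formats scalars, e.g.
# before (PyTorch <= 0.3 idiom; warns on 0.4, errors on later versions):
'loss_train: {:.4f}'.format(loss_train.data[0]),
# after (.item() converts a 0-dim tensor to a Python number):
'loss_train: {:.4f}'.format(loss_train.item()),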
pip install -r requirements.txt
This fails with:
Could not find a version that satisfies the requirement torch==0.4.1.post2 (from -r requirements.txt (line 3)) (from versions: 0.1.2, 0.1.2.post1, 0.3.1, 0.4.0, 0.4.1, 1.0.0, 1.0.1, 1.0.1.post2)
No matching distribution found for torch==0.4.1.post2 (from -r requirements.txt (line 3))
Attempted fix: upgrade pip first (pip install --upgrade pip), then rerun pip install -r requirements.txt.
torch==0.4.1.post2 still cannot be installed. Install the other two packages and move on; the exact torch build has little impact here, and the code runs fine, so we leave it for now.
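If you want pip install -r requirements.txt to go through cleanly, one workaround (my own suggestion, not from the repo) is to repin torch to a build that PyPI actually lists in the error above:
numpy==1.15.1
scipy==1.1.0
torch==0.4.1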
python train.py
Epoch: 0002 loss_train: 1.9417 acc_train: 0.2143 loss_val: 1.9261 acc_val: 0.4600 time: 1.4492s
Epoch: 0003 loss_train: 1.9287 acc_train: 0.2786 loss_val: 1.9161 acc_val: 0.4867 time: 1.4498s
...
Epoch: 0743 loss_train: 0.6635 acc_train: 0.7714 loss_val: 0.6587 acc_val: 0.8167 time: 1.4533s
Epoch: 0744 loss_train: 0.5009 acc_train: 0.8429 loss_val: 0.6582 acc_val: 0.8200 time: 1.4527s
Optimization Finished!
Total time elapsed: 1085.0045s
Loading 643th epoch
Test set results: loss= 0.6640 accuracy= 0.8440
Compared with GCN, GAT takes much longer to run: each epoch is slower and the experiment trains for many more epochs.
The code structure of GAT is very similar to that of GCN.
GAT's training code closely mirrors GCN's training code; see the earlier post: Graph Convolution Network (Part 1): training, running, and code overview.
Parse the arguments, seed the random number generators, and load the data:
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
# Load data
adj, features, labels, idx_train, idx_val, idx_test = load_data()
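The snippet assumes a parser defined earlier in train.py. Below is a minimal sketch consistent with the args.* fields used throughout this file; the defaults shown are illustrative, so check the repo for the exact values:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-cuda', action='store_true', default=False, help='Disable CUDA training.')
parser.add_argument('--fastmode', action='store_true', default=False, help='Validate during the training pass.')
parser.add_argument('--sparse', action='store_true', default=False, help='Use the sparse GAT (SpGAT).')
parser.add_argument('--seed', type=int, default=72, help='Random seed.')
parser.add_argument('--epochs', type=int, default=10000, help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.005, help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay (L2 penalty).')
parser.add_argument('--hidden', type=int, default=8, help='Number of hidden units per head.')
parser.add_argument('--nb_heads', type=int, default=8, help='Number of attention heads.')
parser.add_argument('--dropout', type=float, default=0.6, help='Dropout rate.')
parser.add_argument('--alpha', type=float, default=0.2, help='Negative slope of the LeakyReLU.')
parser.add_argument('--patience', type=int, default=100, help='Early-stopping patience.')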
Define the model and the optimizer
# Model and optimizer
if args.sparse:
    model = SpGAT(nfeat=features.shape[1],
                  nhid=args.hidden,
                  nclass=int(labels.max()) + 1,
                  dropout=args.dropout,
                  nheads=args.nb_heads,
                  alpha=args.alpha)
else:
    model = GAT(nfeat=features.shape[1],
                nhid=args.hidden,
                nclass=int(labels.max()) + 1,
                dropout=args.dropout,
                nheads=args.nb_heads,
                alpha=args.alpha)
optimizer = optim.Adam(model.parameters(),
                       lr=args.lr,
                       weight_decay=args.weight_decay)
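Note that nclass is inferred as int(labels.max()) + 1, i.e. labels are assumed to be 0-indexed class ids (Cora has 7 classes). A quick smoke test of the dense model, assuming the repo's models.py is importable and using made-up toy shapes:
import torch
from models import GAT  # models.py from the pyGAT repo

features = torch.rand(5, 10)   # 5 nodes, 10 input features each
adj = torch.eye(5)             # toy adjacency: self-loops only
model = GAT(nfeat=10, nhid=8, nclass=3, dropout=0.6, nheads=4, alpha=0.2)
model.eval()                   # disable dropout for a deterministic pass
out = model(features, adj)
print(out.shape)               # torch.Size([5, 3]): per-node log-probabilities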
Define the training step for one epoch
Essentially every PyTorch training loop follows this flow: switch to train mode, zero the gradients, run a forward pass, compute the loss, backpropagate, and step the optimizer.
def train(epoch):
    t = time.time()
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)
    loss_train = F.nll_loss(output[idx_train], labels[idx_train])
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()
    if not args.fastmode:
        # Evaluate validation set performance separately,
        # deactivates dropout during validation run.
        model.eval()
        output = model(features, adj)
    loss_val = F.nll_loss(output[idx_val], labels[idx_val])
    acc_val = accuracy(output[idx_val], labels[idx_val])
    print('Epoch: {:04d}'.format(epoch+1),
          'loss_train: {:.4f}'.format(loss_train.data[0]),
          'acc_train: {:.4f}'.format(acc_train.data[0]),
          'loss_val: {:.4f}'.format(loss_val.data[0]),
          'acc_val: {:.4f}'.format(acc_val.data[0]),
          'time: {:.4f}s'.format(time.time() - t))
    return loss_val.data[0]
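The accuracy helper comes from the repo's utils.py; it is essentially the same function as in the pyGCN code, roughly:
def accuracy(output, labels):
    # Predicted class = argmax over the per-node log-probabilities.
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double().sum()
    return correct / len(labels)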
Testing
def compute_test():
    model.eval()
    output = model(features, adj)
    loss_test = F.nll_loss(output[idx_test], labels[idx_test])
    acc_test = accuracy(output[idx_test], labels[idx_test])
    print("Test set results:",
          "loss= {:.4f}".format(loss_test.data[0]),
          "accuracy= {:.4f}".format(acc_test.data[0]))
Training epoch by epoch
# State for early stopping (initialized before the loop in train.py):
t_total = time.time()
loss_values = []
bad_counter = 0
best = args.epochs + 1
best_epoch = 0
for epoch in range(args.epochs):
    loss_values.append(train(epoch))
    torch.save(model.state_dict(), '{}.pkl'.format(epoch))
    if loss_values[-1] < best:
        best = loss_values[-1]
        best_epoch = epoch
        bad_counter = 0
    else:
        bad_counter += 1
    if bad_counter == args.patience:
        break
    # Keep only checkpoints from best_epoch onward; delete older ones.
    files = glob.glob('*.pkl')
    for file in files:
        epoch_nb = int(file.split('.')[0])
        if epoch_nb < best_epoch:
            os.remove(file)
Note the early-stopping structure here: whenever the new validation loss fails to improve on the best one seen so far (not merely the previous epoch's), bad_counter is incremented; once bad_counter reaches patience, training stops.
This part is fairly important and is described in more detail later.
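For completeness: after the loop exits, train.py restores the checkpoint of the best epoch and evaluates it, which is where the "Loading 643th epoch" line in the run output above comes from. Roughly:
print("Optimization Finished!")
print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

# Restore the weights saved at the best validation epoch, then test.
print('Loading {}th epoch'.format(best_epoch))
model.load_state_dict(torch.load('{}.pkl'.format(best_epoch)))
compute_test()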