原理
实际就是通过邻接矩阵来指导模型聚合一阶邻域、二阶邻域等的特征信息,来作为中心点的输出,聚合的方法有多种,至于一阶邻域、二阶邻域等,可以通过不断叠加特征邻域聚合的网络层,来实现向更高阶的扩充。
代码实现:
通过一维卷积增加非线性,赋予不同权重,然后通过矩阵加矩阵的转置,来对attention的权重进行学习,通过-1e9(1-邻接矩阵),以及softmax来过滤掉不与当前顶点相连的节点。通过叠加GAT网络,来对高阶节点信息进行加权。。。
import numpy as np
import tensorflow as tf
from utils import layers
from models.base_gattn import BaseGAttN
class GAT(BaseGAttN):
conv1d = tf.layers.conv1d
def attn_head(seq, out_sz, bias_mat, activation, in_drop=0.0, coef_drop=0.0, residual=False):
with tf.name_scope('my_attn'):
if in_drop != 0.0:
seq = tf.nn.dropout(seq, 1.0 - in_drop)
seq_fts = tf.layers.conv1d(seq, out_sz, 1, use_bias=False)
# simplest self-attention possible
f_1 = tf.layers.conv1d(seq_fts, 1, 1)
f_2 = tf.layers.conv1d(seq_fts, 1, 1)
logits = f_1 + tf.transpose(f_2, [0, 2, 1])
coefs = tf.nn.softmax(tf.nn.leaky_relu(logits) + bias_mat)
if coef_drop != 0.0:
coefs = tf.nn.dropout(coefs, 1.0 - coef_drop)
if in_drop != 0.0:
seq_fts = tf.nn.dropout(seq_fts, 1.0 - in_drop)
vals = tf.matmul(coefs, seq_fts)
ret = tf.contrib.layers.bias_add(vals)
# residual connection
if residual:
if seq.shape[-1] != ret.shape[-1]:
ret = ret + conv1d(seq, ret.shape[-1], 1) # activation
else:
ret = ret + seq
return activation(ret) # activation
def inference(inputs, nb_classes, nb_nodes, training, attn_drop, ffd_drop,
bias_mat, hid_units, n_heads, activation=tf.nn.elu, residual=False):
attns = []
for _ in range(n_heads[0]):
attns.append(layers.attn_head(inputs, bias_mat=bias_mat,
out_sz=hid_units[0], activation=activation,
in_drop=ffd_drop, coef_drop=attn_drop, residual=False))
h_1 = tf.concat(attns, axis=-1)
for i in range(1, len(hid_units)):
h_old = h_1
attns = []
for _ in range(n_heads[i]):
attns.append(layers.attn_head(h_1, bias_mat=bias_mat,
out_sz=hid_units[i], activation=activation,
in_drop=ffd_drop, coef_drop=attn_drop, residual=residual))
h_1 = tf.concat(attns, axis=-1)
out = []
for i in range(n_heads[-1]):
out.append(layers.attn_head(h_1, bias_mat=bias_mat,
out_sz=nb_classes, activation=lambda x: x,
in_drop=ffd_drop, coef_drop=attn_drop, residual=False))
logits = tf.add_n(out) / n_heads[-1]
return logits