Implementing the max aggregator of a GNN in TensorFlow

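I want to implement the max aggregator used in GNNs, with the following requirements:
Input: 1. the adjacency matrix $A\in\{0,1\}^{N\times N}$; 2. the node feature matrix $X\in\mathbb{R}^{N\times C}$.
Output: for every node, the element-wise maximum of its neighbours' features (nodes $i$ and $j$ are neighbours $\Leftrightarrow A[i,j]=1$), i.e. $M\in\mathbb{R}^{N\times C}$ with $M[i,:]=\max_{j:\,A[i,j]=1} X[j,:]$.
Here the number of nodes $N$ is not known in advance, while the feature dimension $C$ is known.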
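Before looking at TensorFlow code, it can help to pin the definition down with a slow but obviously correct reference. The sketch below is plain NumPy and is not part of the original post; the name `max_aggregate_reference` is hypothetical, and giving isolated nodes (no neighbours) a row of -inf is an assumption made here, since the post does not say what such nodes should receive.

```python
import numpy as np

def max_aggregate_reference(A: np.ndarray, X: np.ndarray) -> np.ndarray:
    """Naive reference: M[i] = element-wise max of X[j] over all j with A[i, j] == 1.

    Assumption (not from the post): nodes without neighbours get a row of -inf.
    """
    N, C = X.shape
    M = np.full((N, C), -np.inf)
    for i in range(N):
        neighbours = np.nonzero(A[i])[0]        # indices j with A[i, j] != 0
        if neighbours.size > 0:
            M[i] = X[neighbours].max(axis=0)    # max over the neighbour rows, per channel
    return M

A = np.array([[0, 1, 1],
              [1, 0, 0],
              [1, 0, 0]])
X = np.array([[1.0, 5.0],
              [3.0, 2.0],
              [4.0, -1.0]])
print(max_aggregate_reference(A, X))  # rows: [4, 2], [1, 5], [1, 5]
```

Node 0 has neighbours 1 and 2, so its output is the per-channel maximum of `[3, 2]` and `[4, -1]`, i.e. `[4, 2]`; this is handy as ground truth when testing the faster versions.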

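I first implemented a simple version: tile $X$ $N$ times so that each copy lines up with one of the $N$ rows of the flattened adjacency matrix, use an element-wise product to keep only the neighbours' features (non-neighbours are pushed down by 1e4 so they never win the maximum), and then take the maximum for each node. The code is as follows: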

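import tensorflow as tf

def max_aggregate_tile(A, X):
    # A: [N, N] adjacency matrix with 0/1 entries; X: [N, C] node features.
    output_shape = X.get_shape()
    node_num = tf.shape(X)[0]

    # Cast A to X's dtype so it can be multiplied with X (A may be stored as int or bool).
    flat_A = tf.reshape(tf.cast(A, X.dtype), [-1, 1])               # [N*N, 1], row i*N+j holds A[i, j]
    tiled_X = tf.tile(X, [node_num, 1], name='tiled_flat_X')        # [N*N, C], row i*N+j holds X[j]
    # Keep neighbour features; push non-neighbours down to -1e4 so they never win the max.
    flat_X_dot_A = tf.reshape(tiled_X * flat_A - 1e4 * (1 - flat_A), [node_num, node_num, -1])  # [N, N, C]
    output_X = tf.reduce_max(flat_X_dot_A, axis=1, keepdims=False)  # [N, C]

    output_X.set_shape(output_shape)
    return output_X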
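The masking trick above can be mirrored in NumPy to check the indexing. This is an illustrative sketch, not the post's code: the name `max_aggregate_dense` and its `fill` parameter are hypothetical, and broadcasting is used instead of `tf.tile`, but it builds the same `[N, N, C]` intermediate in which entry `[i, j, :]` is `X[j]` when `A[i, j] == 1`.

```python
import numpy as np

def max_aggregate_dense(A, X, fill=-1e4):
    # A[:, :, None] has shape (N, N, 1) and X[None, :, :] has shape (1, N, C);
    # broadcasting produces an (N, N, C) tensor, like tf.tile + tf.reshape does.
    mask = A[:, :, None].astype(X.dtype)
    # Entry [i, j, :] is X[j] for neighbours and `fill` for non-neighbours.
    masked = X[None, :, :] * mask + fill * (1.0 - mask)
    # Max over the neighbour axis j gives one row per node.
    return masked.max(axis=1), masked.shape

A = np.array([[0, 1, 1],
              [1, 0, 0],
              [1, 0, 0]])
X = np.array([[1.0, 5.0],
              [3.0, 2.0],
              [4.0, -1.0]])
M, intermediate_shape = max_aggregate_dense(A, X)
print(intermediate_shape)  # (3, 3, 2)
print(M)
```

Because non-neighbours are shifted to `fill` rather than removed, a node with no neighbours comes out as `fill` in every channel, and the result is wrong if any real feature is below `fill`; masking with -inf (for example via `np.where`) would avoid the second problem.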

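The tile call in the code above forces an $[N, N, C]$ intermediate tensor, i.e. $O(N^2 C)$ memory ($O(N^3)$ when $C$ is on the order of $N$). This takes a huge amount of GPU memory and leaves TensorFlow no room to optimize it away.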
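To put the memory cost of the tiled tensor in numbers, the hypothetical helper below computes the size of the `[N*N, C]` tile alone, assuming float32 features (4 bytes per element). The element-wise product and the subtraction create further tensors of the same size, so the actual peak is a multiple of this figure.

```python
def intermediate_bytes(num_nodes: int, channels: int, bytes_per_elem: int = 4) -> int:
    # The tiled tensor holds num_nodes * num_nodes rows of `channels` values each.
    return num_nodes * num_nodes * channels * bytes_per_elem

# Hypothetical graph sizes, chosen only for illustration (C = 64, float32).
for n in (1_000, 10_000, 50_000):
    print(n, "nodes:", intermediate_bytes(n, channels=64) / 1e9, "GB")
```

With 10,000 nodes and 64 channels the tile alone is 10,000² × 64 × 4 bytes = 25.6 GB, already beyond a typical single GPU.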

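Switching the neighbour-extraction step from the element-wise product to a gather (`tf.boolean_mask`) combined with `tf.while_loop` also works, and brings the worst-case space complexity down to $O(|E||V|)$; in my measurements, after TensorFlow's own optimizations it was roughly $O(N^2)$. The code is as follows: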

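import tensorflow as tf

# Class wrapper and default scope name are assumed; the original shows only the two methods.
class MaxAggregator(object):
    def __init__(self, name_scope='max_aggregator'):
        self.name_scope = name_scope

    def _maximum_neighborhood(self, index, A, X, out):
        with tf.name_scope(self.name_scope):
            # Rows of X whose entry in A[index] is non-zero, i.e. the neighbours of node `index`.
            # boolean_mask needs a bool mask, so cast the 0/1 row first.
            neigh = tf.boolean_mask(X, tf.cast(A[index], tf.bool))
            # Element-wise max over the neighbours -> [1, C].
            max_neigh = tf.reduce_max(neigh, keepdims=True, axis=0)
            # Append the result as row `index` of the output.
            out = tf.concat([out, max_neigh], axis=0)
        return out

    def __call__(self, A, X):
        '''
        input arguments:
            A is the graph adjacency matrix of type tf.Tensor and of shape [N,N]
            X is the node attributes matrix of type tf.Tensor and of shape [N,C]
            where N is the number of nodes and C is the number of node attribute channels
        output arguments:
            aggregated new node attributes X' of type tf.Tensor and of shape [N,C]
        '''
        with tf.name_scope(self.name_scope):
            output_shape = X.get_shape()
            node_num = tf.shape(X)[0]
            output_dim = int(output_shape[-1])

            # Start from an empty [0, C] tensor; each iteration appends one row.
            output_X = tf.zeros([0, output_dim], dtype=X.dtype)
            _, _, _, output_X = tf.while_loop(
                lambda index, A, X, out: index < node_num,
                lambda index, A, X, out: [index + 1, A, X, self._maximum_neighborhood(index, A, X, out)],
                loop_vars=[tf.zeros([], tf.int32), A, X, output_X],
                # `out` gains a row per iteration, so its first dimension must be left as None.
                shape_invariants=[tf.TensorShape([]), A.get_shape(), X.get_shape(),
                                  tf.TensorShape([None, output_dim])])

            output_X.set_shape(output_shape)
            return output_X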
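A loop-free alternative, which is not the method used above, is to turn A into an edge list, gather one feature row per edge ($|E| \times C$ values) and scatter-max those rows back onto their target nodes. The NumPy sketch below shows the idea; `max_aggregate_edges` is a hypothetical name, and isolated nodes get -inf by assumption. In TensorFlow the same pattern should map onto `tf.where(A > 0)` for the edge list, `tf.gather` for the rows and `tf.math.unsorted_segment_max` for the scatter-max, but that variant has not been benchmarked here against the while_loop version.

```python
import numpy as np

def max_aggregate_edges(A, X):
    """Edge-list max aggregation: gather per edge, then scatter-max per node."""
    # Every non-zero A[i, j] becomes one edge (target i, source j).
    targets, sources = np.nonzero(A)
    # One feature row per edge: shape (|E|, C) instead of (N, N, C).
    gathered = X[sources]
    # Start every node at -inf (assumption for nodes without neighbours).
    M = np.full(X.shape, -np.inf)
    # Unbuffered scatter-max: M[targets[k]] = max(M[targets[k]], gathered[k]) for each edge k,
    # which stays correct when the same target node appears many times.
    np.maximum.at(M, targets, gathered)
    return M

A = np.array([[0, 1, 1],
              [1, 0, 0],
              [1, 0, 0]])
X = np.array([[1.0, 5.0],
              [3.0, 2.0],
              [4.0, -1.0]])
print(max_aggregate_edges(A, X))  # rows: [4, 2], [1, 5], [1, 5]
```

The memory here scales with the number of edges rather than $N^2 C$, which matters most for sparse graphs; for a dense graph $|E| \approx N^2$ and the savings disappear.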
