GAT中的如何实现稀疏注意力

之前一直没看GAT的代码(https://github.com/PetarV-/GAT),不知道稀疏矩阵下如何实现注意力的,今天看到,恍然大悟,记录于此

首先,由于稀疏矩阵参与运算时其中的参数不能自动更新(pytorchz中暂时没有其反向传播函数),所以GAT自己写了稀疏矩阵(计算完注意力后的邻接矩阵)与稠密矩阵(特征)的乘法

class SpecialSpmmFunction(torch.autograd.Function):
    """Special function for only sparse region backpropataion layer."""
    @staticmethod
    def forward(ctx, indices, values, shape, b):
        assert indices.requires_grad == False
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        ctx.N = shape[0]
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            grad_a_dense = grad_output.matmul(b.t())
            edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
            grad_values = grad_a_dense.view(-1)[edge_idx]
        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b


class SpecialSpmm(nn.Module):
    def forward(self, indices, values, shape, b):
        return SpecialSpmmFunction.apply(indices, values, shape, b)

在需要就算注意力的时候,就可以通过以下方式

class SpGraphAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, dropout, alpha, concat=True):
        super(SpGraphAttentionLayer, self).__init__()
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.zeros(size=(in_dim, out_dim)),requires_grad=True)
        nn.init.xavier_normal_(self.W.data, gain=1.414)
                
        self.a = nn.Parameter(torch.zeros(size=(1, 2*out_dim)),requires_grad=True)
        nn.init.xavier_normal_(self.a.data, gain=1.414)

        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.special_spmm = SpecialSpmm()

    def forward(self, input, adj):
        dv = 'cuda' if input.is_cuda else 'cpu'
        N = input.size()[0]#节点数量2708
        edge_index = adj.nonzero().t()
        
         #先对所有特征进行一次线性变换
        input=torch.mm(input.self.W)

        # 连接节点与其所有邻居的表示
        hidden = torch.cat((input[edge_index[0, :], :], input[edge_index[1, :], :]), dim=1).t()
        
        #通过向量a得到分数
        edge_value = torch.exp(-self.leakyrelu(self.a.mm(hidden).squeeze()))#注意力中的分子
        e_rowsum = self.special_spmm(edge_index, edge_value, torch.Size([N, N]), torch.ones(size=(N,1), device=dv))

        edge_value= self.dropout(edge_value)
        
        #这也是一个技巧,乘完特征再去除,跟算完注意力再去乘特征是一个道理
        h_prime = self.special_spmm(edge_index, edge_value, torch.Size([N, N]), input)
        h_prime = h_prime.div(e_rowsum)
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

 

你可能感兴趣的:(应用)