Reading the pyGAT source code

layers.GraphAttentionLayer._prepare_attentional_mechanism_input

Wh.shape
Out[1]: torch.Size([2708, 8])
Wh_repeated_in_chunks.shape
Out[3]: torch.Size([7333264, 8])
N
Out[4]: 2708
N**2
Out[5]: 7333264

Wh_repeated_in_chunks.shape == Wh_repeated_alternating.shape == (N * N, out_features)
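
The variable names suggest the two tensors are built with Wh.repeat_interleave(N, dim=0) and Wh.repeat(N, 1). A toy check of what those two calls produce (my own sketch, assuming that is indeed how they are constructed):

import torch

Wh = torch.tensor([[1.], [2.], [3.]])     # pretend N = 3, out_features = 1
N = Wh.size(0)

chunks = Wh.repeat_interleave(N, dim=0)   # e1,e1,e1, e2,e2,e2, e3,e3,e3 -> "in chunks"
alternating = Wh.repeat(N, 1)             # e1,e2,e3, e1,e2,e3, e1,e2,e3 -> "alternating"

print(chunks.view(-1).tolist())           # [1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0]
print(alternating.view(-1).tolist())      # [1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]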

The purpose is simply to concatenate the two, pairing every node's features with every other node's:

# e1 || e1
# e1 || e2
# e1 || e3
# ...
# e1 || eN
# e2 || e1
# e2 || e2
# e2 || e3
# ...
# e2 || eN
# ...
# eN || e1
# eN || e2
# eN || e3
# ...
# eN || eN
all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)
# all_combinations_matrix.shape == (N * N, 2 * out_features)

return all_combinations_matrix.view(N, N, 2 * self.out_features)
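
A quick sanity check of the result (my own sketch, again assuming the two repeated tensors come from repeat_interleave and repeat): after the final view, position [i, j] should hold the concatenation Wh[i] || Wh[j].

import torch

N, out_features = 4, 2
Wh = torch.randn(N, out_features)

chunks = Wh.repeat_interleave(N, dim=0)        # assumed construction of the two tensors
alternating = Wh.repeat(N, 1)
a_input = torch.cat([chunks, alternating], dim=1).view(N, N, 2 * out_features)

i, j = 1, 3
assert torch.equal(a_input[i, j], torch.cat([Wh[i], Wh[j]]))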

e is an $N \times N$ matrix that measures the similarity between pairs of nodes. attention = torch.where(adj > 0, e, zero_vec) first masks out entries for unconnected pairs, and softmax followed by dropout then turns the result into attention weights in $[0, 1]$.

h_prime = torch.matmul(attention, Wh) then aggregates each node's information from its weighted neighbourhood.

e.shape
Out[8]: torch.Size([2708, 2708])
e.max()
Out[9]: tensor(0.4182, device='cuda:0', grad_fn=<MaxBackward1>)
e.min()
Out[10]: tensor(-0.0771, device='cuda:0', grad_fn=<MinBackward1>)
e.mean()
Out[11]: tensor(0.0167, device='cuda:0', grad_fn=<MeanBackward0>)
Wh = torch.mm(h, self.W)  # h.shape: (N, in_features), Wh.shape: (N, out_features)
a_input = self._prepare_attentional_mechanism_input(Wh)
e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

zero_vec = -9e15 * torch.ones_like(e)
attention = torch.where(adj > 0, e, zero_vec)
attention = F.softmax(attention, dim=1)
attention = F.dropout(attention, self.dropout, training=self.training)
h_prime = torch.matmul(attention, Wh)
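
A toy illustration (not repo code) of why the -9e15 mask works: after softmax, the masked entries become numerically zero, so each row of attention only distributes weight over a node's actual neighbours.

import torch
import torch.nn.functional as F

e_row   = torch.tensor([0.2, -0.1, 0.4])   # raw scores of node 0 towards nodes 0, 1, 2
adj_row = torch.tensor([1., 0., 1.])       # node 0 is connected to itself and to node 2

masked = torch.where(adj_row > 0, e_row, torch.full_like(e_row, -9e15))
print(F.softmax(masked, dim=0))            # ~ tensor([0.4502, 0.0000, 0.5498])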

Because this is multi-head attention, the same operation is repeated 8 times and the results are concatenated along the feature dimension, giving [2708, 64]:

x = F.dropout(x, self.dropout, training=self.training)
x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
x = F.dropout(x, self.dropout, training=self.training)
x = F.elu(self.out_att(x, adj))
return F.log_softmax(x, dim=1)

The whole model amounts to a hidden layer plus an output layer: the hidden layer is the multi-head attention, realizing $1433 \rightarrow 8 \times 8$; the output layer is a single attention head, realizing $64 \rightarrow 7$ and directly outputting the logits.
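
Putting that together, models.GAT wires the two stages up roughly as below. This sketch is paraphrased from memory rather than copied from the repo, so treat the exact GraphAttentionLayer constructor arguments (dropout, alpha, concat) as assumptions; with the default Cora settings (nfeat=1433, nhid=8, nheads=8, nclass=7) it realizes 1433 → 8 per head, 8 heads concatenated to 64, and then 64 → 7.

import torch
import torch.nn as nn
import torch.nn.functional as F

from layers import GraphAttentionLayer   # the repo's layer; signature below is assumed

class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        super().__init__()
        self.dropout = dropout
        # Hidden layer: nheads independent heads, each nfeat -> nhid (1433 -> 8 on Cora)
        self.attentions = nn.ModuleList(
            GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
            for _ in range(nheads))
        # Output layer: a single head over the concatenated features, nhid*nheads -> nclass (64 -> 7)
        self.out_att = GraphAttentionLayer(nhid * nheads, nclass,
                                           dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)   # (N, nhid * nheads)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))                                  # (N, nclass)
        return F.log_softmax(x, dim=1)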

My GPU is too weak, so I had to reduce the number of attention heads:

--nb_heads 3 --epochs 300
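
With the repo's remaining defaults kept and only those two flags overridden, the run looks roughly like this (the script name and the other defaults are assumed from the repo):

python train.py --nb_heads 3 --epochs 300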

The final test accuracy is 0.826, a big improvement over GCN.
