图中每一个节点都由 d d d维实数特征向量表示(比如可以作为节点的嵌入编码), n n n个节点的特征向量可以写成 M n × d M_{n\times d} Mn×d的矩阵形式。
经过图注意力层以后,输出每一个节点的 d ′ d' d′维新特征并形成矩阵 M n × d ′ ′ M'_{n\times d'} Mn×d′′. 有 W d ′ × d W^{d'\times d} Wd′×d作为变换矩阵作用于每个节点的特征向量 h i ( d × 1 ) h_i^{(d\times 1)} hi(d×1).
为了保留图的结构信息,GAT的一种计算方式只考虑节点的邻接节点并且认为该节点的每个邻节点对其影响力可以用不同的权重(即注意力系数, attention coefficient)表示,以下用一阶邻节点举例(另一种则是考虑全部的顶点):
假设节点 i i i有一个邻节点 j j j, 经过线性变换以后分别是 W h i Wh_i Whi和 W h j Wh_j Whj。再假设有一个映射 a : R d ′ × R d ′ → R a: \R^{d'}\times \R^{d'}\to \R a:Rd′×Rd′→R, 那么邻节点对该节点的注意力系数是:
e i j = a ( W h i , W h j ) e_{ij}=a(Wh_i,Wh_j) eij=a(Whi,Whj)
GAT的具体实现方式是 a ∈ R 1 × 2 d ′ a\in \R^{1\times 2d'} a∈R1×2d′, ( W h i , W h j ) = [ W h i ∣ ∣ W h j ] (Wh_i,Wh_j) = [Wh_i||Wh_j] (Whi,Whj)=[Whi∣∣Whj](把两个向量合并在一起,形成一个 ( 2 d ′ , 1 ) (2d',1) (2d′,1)维的向量),那么 e i j = a T [ W h i ∣ ∣ W h j ] e_{ij}=a^\mathsf{T}[Wh_i||Wh_j] eij=aT[Whi∣∣Whj].
当节点有多个邻节点时,为了避免某个注意力系数的值远大于其他值不便于训练,需要normalization。同时为了泛化模型的拟合能力,对线性变化后的值可以加入非线性激活函数, 最终得到的注意力系数:
α i j = exp ( LeakyRelu ( e i j ) ) ) ∑ k ∈ N i exp ( LeakyRelu ( e i k ) \alpha_{ij}=\frac{\exp(\operatorname{LeakyRelu}(e_{ij})))}{\sum_{k\in\mathcal{N_i} }\exp(\operatorname{LeakyRelu}(e_{ik})} αij=∑k∈Niexp(LeakyRelu(eik)exp(LeakyRelu(eij)))
其中 e i ∈ N i e_i\in \mathcal{N_i} ei∈Ni.
得出节点及其邻节点的注意力系数以后,就可以用于结合 W W W更好地更新 h ′ h' h′了,论文中使用的聚合函数:
h i ′ = σ ( ∑ j ∈ N i α i j W h j ) h'_i = \sigma(\sum_{j\in\mathcal{N_i} }\alpha_{ij}Wh_j) hi′=σ(j∈Ni∑αijWhj)
为了提高聚合器的表现,论文中采用了multi-head attention, 即使用 k k k个独立的注意力机制(采用不同的 a a a和 W W W),然后将得到的结果再次拼接——
h i ′ = ∣∣ k = 1 k σ ( ∑ j ∈ N i α i j ( k ) W ( k ) h j ) h_i'=\operatorname{||}_{k=1}^k \sigma(\sum_{j\in\mathcal{N_i} }\alpha^{(k)}_{ij}W^{(k)}h_j) hi′=∣∣k=1kσ(j∈Ni∑αij(k)W(k)hj)
这会导致 h i ′ h_i' hi′有更高的维度 ( 1 , k d ′ ) (1,kd') (1,kd′),所以只可以做中间层而不可以做输出层。
所以对于输出层,一种聚合方式是将各注意力机制的 h ′ h' h′平均
Output= h ′ = σ ( 1 k ∑ i = 1 k ∑ j ∈ N i α i j ( k ) W ( k ) h j ) \text{Output=}h'=\sigma(\frac{1}{k}\sum_{i=1}^{k}\sum_{j\in\mathcal{N_i} }\alpha^{(k)}_{ij}W^{(k)}h_j) Output=h′=σ(k1i=1∑kj∈Ni∑αij(k)W(k)hj)
代码地址(PyTorch版本):https://github.com/Diego999/pyGAT
图注意力层(即上文原理中提到的注意力机制)的功能是接受由各节点特征向量组成的特征矩阵 H n × d H_{n\times d} Hn×d, 输出新的特征矩阵 H n × d ′ H_{n\times d'} Hn×d′.
一共有两组参数, W W W和 a a a,需要训练。其中 a a a适用于所有的特征向量对。
self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
nn.init.xavier_uniform_(self.W.data, gain=1.414)
self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
nn.init.xavier_uniform_(self.a.data, gain=1.414)
def forward(self, h, adj):
# 首先对节点的本身特征向量进行线性变换
Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features)
a_input = self._prepare_attentional_mechanism_input(Wh)
e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2)) # 计算未normalized的注意力系数
zero_vec = -9e15*torch.ones_like(e)
attention = torch.where(adj > 0, e, zero_vec) # 只计算邻接节点
attention = F.softmax(attention, dim=1)
attention = F.dropout(attention, self.dropout, training=self.training)
h_prime = torch.matmul(attention, Wh)
# 输出层的self.concat为False, 不进行非线性变化
if self.concat:
return F.elu(h_prime)
else:
return h_prime
def _prepare_attentional_mechanism_input(self, Wh):
N = Wh.size()[0] # number of nodes
Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0) # 对每一个特征向量重复N次
Wh_repeated_alternating = Wh.repeat(N, 1) # 将特征矩阵重复N次
# 下面得到每个节点和其他所有节点组合并拼接而成的特征向量
all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)
# all_combinations_matrix.shape == (N * N, 2 * out_features)
return all_combinations_matrix.view(N, N, 2 * self.out_features)
整个GAT的框架就非常直观了,在输入层添加dropout防止过拟合,只有中间层使用了拼接法。
def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
"""Dense version of GAT."""
super(GAT, self).__init__()
self.dropout = dropout
self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
for i, attention in enumerate(self.attentions):
self.add_module('attention_{}'.format(i), attention)
self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
def forward(self, x, adj):
x = F.dropout(x, self.dropout, training=self.training)
x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
x = F.dropout(x, self.dropout, training=self.training)
x = F.elu(self.out_att(x, adj))
return F.log_softmax(x, dim=1)