GCN Paper Reading and Code Notes (1): AGCRN

Traditional GCN-based traffic forecasting defines the adjacency matrix in advance from distances or similarities between nodes. Such a pre-defined graph cannot capture complete information about the spatial dependencies and has no direct relation to the prediction task, which can introduce considerable bias. Moreover, without suitable domain knowledge these methods do not transfer to other domains, rendering existing GCN-based models ineffective. The paper therefore proposes DAGG to learn the graph adaptively.

AGCRN enhances the GCN with two adaptive modules for the traffic forecasting task:

  • Node Adaptive Parameter Learning (NAPL) module: replaces the weight matrix shared across all nodes with node-specific parameters, drawn from a shared weight pool W_{G} (and bias pool b_{G}) according to the node embeddings E_{G}:

Z = (I_{N}+D^{-0.5}AD^{-0.5})XE_{G}W_{G}+E_{G}b_{G}

  • Data Adaptive Graph Generation (DAGG) module: infers the normalized adjacency matrix directly from learnable node embeddings E_{A}, so no pre-defined graph is needed (both modules are sketched in code right after this list):

D^{-0.5}AD^{-0.5} = softmax(ReLU(E_{A}E_{A}^{T}))

Z = (I_{N}+softmax(ReLU(E_{A}E_{A}^{T})))X\theta

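Both modules hinge on the same learnable node-embedding matrix. Below is a minimal PyTorch sketch of the two equations; all sizes are illustrative and the variable names are ours, not the repo's:

import torch
import torch.nn as nn
import torch.nn.functional as F

num_nodes, embed_dim = 307, 10          # illustrative sizes (PeMSD4 has 307 sensors)
dim_in, dim_out = 1, 64                 # illustrative channel sizes

# learnable node embeddings; here a single E serves both DAGG and NAPL
E = nn.Parameter(torch.randn(num_nodes, embed_dim))

# DAGG: infer the normalized adjacency directly from the embeddings
A_tilde = F.softmax(F.relu(E @ E.T), dim=1)                 # [N, N]

# NAPL: draw node-specific weights and biases from shared pools
W_pool = nn.Parameter(torch.randn(embed_dim, dim_in, dim_out))
b_pool = nn.Parameter(torch.randn(embed_dim, dim_out))
W = torch.einsum('nd,dio->nio', E, W_pool)                  # [N, dim_in, dim_out]
b = E @ b_pool                                              # [N, dim_out]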
In AGCRN, a Chebyshev polynomial approximation of the graph Laplacian implements the graph convolution. The Chebyshev recurrence relation (with base cases T_0(x) = 1 and T_1(x) = x):

T_k(x) = 2xT_{k-1}(x)-T_{k-2}(x)

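Applied to the learned adjacency \widetilde{A}, with T_0 = I_{N} and T_1 = \widetilde{A}, the first few supports are for example:

T_0(\widetilde{A}) = I_{N},\quad T_1(\widetilde{A}) = \widetilde{A},\quad T_2(\widetilde{A}) = 2\widetilde{A}^{2}-I_{N}

which is exactly the support_set list built in the code below.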
In the ablation study, removing the identity matrix from the supports degrades prediction performance substantially, showing that highlighting self-information is important for forecasting, while the gap between second-order and first-order polynomials is small.

The corresponding code:

# class AVWGCN
def forward(self, x, node_embeddings):
    # x: [B, N, C]; node_embeddings: [N, D] -> supports: [N, N]
    # output: [B, N, C]
    node_num = node_embeddings.shape[0]
    # DAGG: adjacency inferred from the node embeddings (matrix product E_A @ E_A^T)
    supports = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1)
    support_set = [torch.eye(node_num).to(supports.device), supports]
    # default cheb_k = 3; cheb_k is the order of the Chebyshev polynomial
    for k in range(2, self.cheb_k):
        # Chebyshev recurrence: T_k = 2 * A_tilde * T_{k-1} - T_{k-2}
        support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2])
    supports = torch.stack(support_set, dim=0)
    # NAPL: node-specific weights and biases generated from the shared pools
    weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool)  # N, cheb_k, dim_in, dim_out
    bias = torch.matmul(node_embeddings, self.bias_pool)  # N, dim_out
    x_g = torch.einsum("knm,bmc->bknc", supports, x)  # B, cheb_k, N, dim_in
    x_g = x_g.permute(0, 2, 1, 3)  # B, N, cheb_k, dim_in
    x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias  # B, N, dim_out
    return x_gconv
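The forward pass references self.weights_pool and self.bias_pool, which are defined in the module's constructor. A minimal reconstruction consistent with the einsum shapes above (a sketch; initialization details may differ from the repo):

import torch
import torch.nn as nn

class AVWGCN(nn.Module):
    def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
        super().__init__()
        self.cheb_k = cheb_k
        # NAPL parameter pools: node-specific weights and biases are
        # generated from these via the node embeddings (see forward)
        self.weights_pool = nn.Parameter(torch.empty(embed_dim, cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.empty(embed_dim, dim_out))
        nn.init.xavier_uniform_(self.weights_pool)
        nn.init.zeros_(self.bias_pool)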
  • The paper proposes the Adaptive Graph Convolutional Recurrent Network (AGCRN), combining NAPL-GCN, DAGG, and Gated Recurrent Units (GRU): the MLP layers inside the GRU are replaced by the adaptive graph convolution, giving

\widetilde{A} = softmax(ReLU(EE^{T}))

z_t = \sigma(\widetilde{A}[X_{:,t},h_{t-1}]EW_{z}+Eb_{z})

r_t = \sigma(\widetilde{A}[X_{:,t},h_{t-1}]EW_{r}+Eb_{r})

\widehat{h}_{t} = tanh(\widetilde{A}[X_{:,t},r\odot h_{t-1}]EW_{\widehat{h}}+Eb_{\widehat{h}})

h_t = z\odot h_{t-1}+(1-z)\odot\widehat{h}_{t}

The corresponding code:

# class AGCRNCell
def forward(self, x, state, node_embeddings):
    # x: B, num_nodes, input_dim
    # state: B, num_nodes, hidden_dim
    state = state.to(x.device)
    input_and_state = torch.cat((x, state), dim=-1)  # [X_{:,t}, h_{t-1}]
    # self.gate is an AVWGCN module producing both gates at once
    z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings))
    z, r = torch.split(z_r, self.hidden_dim, dim=-1)
    # note: relative to the equations above, the names z and r are swapped here;
    # the computation is the same up to renaming the two gates
    candidate = torch.cat((x, z*state), dim=-1)
    hc = torch.tanh(self.update(candidate, node_embeddings))  # self.update is also an AVWGCN module
    h = r*state + (1-r)*hc
    return h
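For completeness, a constructor sketch consistent with the cell's forward pass: the gate branch outputs 2 * hidden_dim channels so a single split yields both gates, and the update branch produces the candidate state (again a reconstruction, not necessarily the repo verbatim):

import torch
import torch.nn as nn

class AGCRNCell(nn.Module):
    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
        super().__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        # both branches are AVWGCN modules operating on [x, state]
        self.gate = AVWGCN(dim_in + dim_out, 2 * dim_out, cheb_k, embed_dim)  # -> z, r
        self.update = AVWGCN(dim_in + dim_out, dim_out, cheb_k, embed_dim)    # -> candidate

    def init_hidden_state(self, batch_size):
        # zero initial hidden state: B, num_nodes, hidden_dim
        return torch.zeros(batch_size, self.node_num, self.hidden_dim)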

The top-level AGCRN model (a GCN-based GRU encoder followed by a CNN-based predictor):

# class AGCRN
def forward(self, source, targets=None, teacher_forcing_ratio=0.5):
    # source: B, T_1, N, D
    # targets: B, T_2, N, D (unused here: prediction is one-shot, no teacher forcing)
    # supports = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec1.transpose(0, 1))), dim=1)

    # encoder: a GCN-based GRU that iterates over the input time steps
    init_state = self.encoder.init_hidden(source.shape[0])
    output, _ = self.encoder(source, init_state, self.node_embeddings)  # B, T, N, hidden
    output = output[:, -1:, :, :]  # keep only the last hidden state: B, 1, N, hidden

    # CNN-based predictor: all horizon steps are produced in one shot
    output = self.end_conv(output)  # B, T*C, N, 1
    output = output.squeeze(-1).reshape(-1, self.horizon, self.output_dim, self.num_node)
    output = output.permute(0, 1, 3, 2)  # B, T, N, C

    return output
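To see the shapes end to end, a hypothetical invocation (the real repo builds the model from a config object; the field names and sizes below are assumptions for this sketch):

from types import SimpleNamespace
import torch

# hypothetical config; field names are assumptions for illustration
args = SimpleNamespace(num_nodes=307, input_dim=1, rnn_units=64, output_dim=1,
                       horizon=12, num_layers=2, cheb_k=2, embed_dim=10,
                       default_graph=True)
model = AGCRN(args)
source = torch.randn(32, 12, 307, 1)   # B=32, T_1=12, N=307, D=1
pred = model(source)                   # -> [32, 12, 307, 1], i.e. B, horizon, N, C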
