忘了 忘了,以前学的矩阵知识全交给老师了,敲黑板了,矩阵乘法实例讲解

忘了 忘了,以前学的矩阵知识全交给老师了,敲黑板了,矩阵乘法实例讲解_第1张图片

忘了 忘了,以前学的矩阵知识全交给老师了,敲黑板了,矩阵乘法实例讲解_第2张图片

在这个地方整蒙了,W1和W2这俩是矩阵的标记,但是后面只有个Tr-1和Cr-1,我????

实际上,1 x Tr-1这种表示,即可以是向量也可以是矩阵呀,没有问题,往下理解是向量,往上理解就是矩阵,

其次,W1与 做运算是将W1当作Tr-1 × 1的矩阵进行运算,而W3和 运算是将其当成1 × Cr-1进行运算,我们可以看到W1和W3是都是用的R^(?)的形式表示的矩阵,但是运算时候却不一样,敲黑板了!这就是我整蒙圈的地方,你想R^(?)是个欧几里得空间,1×Cr-1和Cr-1 × 1都是R^Cr-1的欧几里得空间,因此R^(?)这个东西我们是根据运算可以实时调整它是1 x c还是c x 1的!

 

接下来就是R^(N x Cr-1 x Tr-1),这个???这咋个运算哦?

敲黑板!这种高维的,我们以低维的眼光去分解他,看成N个Cr-1 x Tr-1的矩阵,那么W3与分别与N个相乘,那么自然就得到了N个1 x Tr-1的矩阵,再写成R就是R^(N x Tr-1)了。再转置一下,右边W3的计算结果就是一个R^(Tr-1 x N)的值

 

好!那么代码里面怎么搞得呢?先上代码:

# defer the shape of params
        self.W_1.shape = (num_of_timesteps, )  #对标初始化的时间长度
        self.W_2.shape = (num_of_features, num_of_timesteps)   #特征数目 x 时间长度
        self.W_3.shape = (num_of_features, )
        self.b_s.shape = (1, num_of_vertices, num_of_vertices)
        self.V_s.shape = (num_of_vertices, num_of_vertices)
        for param in [self.W_1, self.W_2, self.W_3, self.b_s, self.V_s]:
            param._finish_deferred_init()    #完成模块的初始化

        # compute spatial attention scores
        # shape of lhs is (batch_size, V, T)
        lhs = nd.dot(nd.dot(x, self.W_1.data()), self.W_2.data())  #lhs式子左部

        # shape of rhs is (batch_size, T, V)
        #之所以x是四维的是因为批处理的问题,batch_size=10意思是10个样本同时进行计算。x的维度是10 x V x features x T
        rhs = nd.dot(self.W_3.data(), x.transpose((2, 0, 3, 1)))
        
        product = nd.batch_dot(lhs, rhs)

        S = nd.dot(self.V_s.data(),
                   nd.sigmoid(product + self.b_s.data())
                     .transpose((1, 2, 0))).transpose((2, 0, 1))

诶?这里面咋有多了一维?啥情况,原因是我们训练是一个batch一个batch训练的

那么这里面咋没有 .T的转置操作呢?讲解

一开始batch_size x N x C x T,那么你当然可以直接算,然后再来个转置,最终能够满足Tr-1 x N的结果,但如果你先对这个高维的矩阵进行一下transpose的变化,变换成C x batch x T x N这个形式再和W3运算,就可以直接得到T x N就不用做转置了。类似于上面三维的情况,C x batch x T x N做运算的时候看成C x  1的矩阵,里面的1是batch x T x N,然后就成了batch x T x N,因为batch是训练用的,其实就是Tr-1 x N了

 

参考资料:Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting 2020.AAAI

                  https://github.com/wanhuaiyu/ASTGCN

你可能感兴趣的:(GNN学习笔记,python)