CS224W作业 class GAT(MessagePassing)

2023年CS224W assignment中的所有Colab答案以及注释已经上传到GitHub：https://github.com/yuyu990116/CS224W-assignment
CS224W课程地址:http://web.stanford.edu/class/cs224w/

class GAT(MessagePassing):
    """Multi-head graph attention layer (GAT) built on PyG's ``MessagePassing``.

    Node features are projected by a linear map into ``heads`` separate
    ``out_channels``-dimensional spaces.

    Attention logits:
        For every edge ``j -> i`` and every head ``h``, a scalar logit is
        computed:

            e_ij^h = LeakyReLU(a_l^h . (W x_j)^h + a_r^h . (W x_i)^h)

        The logits are normalised with a softmax over all edges that share
        the same target node. The result weights the messages ``(W x_j)^h``.

    Output:
        The per-head results are concatenated, so the layer maps
        ``[N, in_channels]`` to ``[N, heads * out_channels]``.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each head's output feature vector.
        heads: Number of attention heads.
        negative_slope: Negative slope of the LeakyReLU applied to the
            attention logits.
        dropout: Dropout probability applied to the normalised attention
            coefficients (only active in training mode).
        **kwargs: Forwarded unchanged to ``MessagePassing.__init__``.
    """

    def __init__(self, in_channels, out_channels, heads=2,
                 negative_slope=0.2, dropout=0., **kwargs):
        # Node tensors are laid out as [N, H, C], so the node axis that PyG
        # gathers/scatters along is dimension 0 rather than the default -2.
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout

        # W_l: projects every node into `heads` spaces of size out_channels.
        self.lin_l = Linear(in_channels, heads * out_channels)
        # W_r: the very same module as W_l, so source and target nodes share
        # one projection.
        self.lin_r = self.lin_l

        # Attention vectors a_l / a_r: one out_channels-dim vector per head.
        # The leading singleton axis broadcasts over the node dimension.
        self.att_l = Parameter(torch.empty(1, heads, out_channels))
        self.att_r = Parameter(torch.empty(1, heads, out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """(Re-)initialise projection and attention weights with Xavier uniform."""
        nn.init.xavier_uniform_(self.lin_l.weight)
        # lin_r is lin_l, so this re-initialises the same tensor; it is kept
        # so the method stays correct if the two projections are ever split.
        nn.init.xavier_uniform_(self.lin_r.weight)
        nn.init.xavier_uniform_(self.att_l)
        nn.init.xavier_uniform_(self.att_r)

    def forward(self, x, edge_index, size=None):
        """Run one round of multi-head attention message passing.

        Args:
            x: Node features; must be 2-D with last dim ``in_channels``.
            edge_index: Graph connectivity accepted by ``propagate``
                (e.g. a ``[2, E]`` index tensor).
            size: Optional ``(num_src, num_dst)`` forwarded to ``propagate``.

        Returns:
            Node embeddings of shape ``[N, heads * out_channels]``, with the
            heads concatenated.
        """
        H, C = self.heads, self.out_channels

        # Project and split into heads: [N, H * C] -> [N, H, C].
        x_l = self.lin_l(x).view(-1, H, C)
        x_r = self.lin_r(x).view(-1, H, C)

        # One scalar attention term per node *and head*: reduce over the
        # channel axis (last), not the head axis, giving [N, H].
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        alpha_r = (x_r * self.att_r).sum(dim=-1)

        out = self.propagate(edge_index, x=(x_l, x_r),
                             alpha=(alpha_l, alpha_r), size=size)

        # Concatenate the heads: [N, H, C] -> [N, H * C].
        return out.view(-1, H * C)

    def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):
        """Compute attention-weighted messages for every edge.

        Args:
            x_j: Projected source-node features, ``[E, H, C]``.
            alpha_j: Source attention terms (from ``alpha_l``), ``[E, H]``.
            alpha_i: Target attention terms (from ``alpha_r``), ``[E, H]``.
            index: Target-node index of each edge; defines softmax groups.
            ptr: Optional CSR pointer passed through to ``softmax``.
            size_i: Number of target nodes.

        Returns:
            Messages of shape ``[E, H, C]``.
        """
        # negative_slope is the LeakyReLU slope for negative inputs.
        alpha = F.leaky_relu(alpha_i + alpha_j, self.negative_slope)
        # Normalise across all edges pointing to the same target node.
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        # Broadcast each head's scalar weight over its C channels: [E, H, 1].
        return x_j * alpha.unsqueeze(-1)

    def aggregate(self, inputs, index, dim_size=None):
        """Sum the incoming messages of each target node.

        Args:
            inputs: Messages from ``message``, ``[E, H, C]``.
            index: Target-node index of each message.
            dim_size: Number of output nodes (size of the node axis).

        Returns:
            Aggregated features with the node axis at ``self.node_dim``.
        """
        return torch_scatter.scatter(inputs, index, dim=self.node_dim,
                                     dim_size=dim_size, reduce='sum')

你可能感兴趣的:(CS224W,机器学习,人工智能,深度学习,神经网络)