2023年CS224W assignment中的所有colab答案以及注释已经上传到GitHub:https://github.com/yuyu990116/CS224W-assignment
CS224W课程地址:http://web.stanford.edu/class/cs224w/
class GAT(MessagePassing):
    """Multi-head graph attention layer built on ``MessagePassing``.

    Node features are linearly projected into ``heads`` separate
    ``out_channels``-sized spaces.  For every edge and every head a scalar
    score is formed from the two endpoint projections, passed through a
    LeakyReLU, normalised with a softmax over the edges that share the same
    target index, and used to weight the source projection.  Weighted
    messages are summed per target node and the heads are concatenated.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each per-head output feature vector.
        heads: Number of attention heads.
        negative_slope: Negative slope of the LeakyReLU applied to the raw
            attention scores.
        dropout: Dropout probability applied to the normalised attention
            coefficients (active only in training mode).
        **kwargs: Forwarded unchanged to ``MessagePassing.__init__``.
    """

    def __init__(self, in_channels, out_channels, heads=2,
                 negative_slope=0.2, dropout=0., **kwargs):
        # node_dim=0: nodes are indexed along the first tensor dimension,
        # which is also the dimension ``aggregate`` scatters along.
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.negative_slope = negative_slope
        self.dropout = dropout

        # W_l: projection applied to the node features.
        self.lin_l = Linear(in_channels, heads * out_channels)
        # W_r deliberately aliases W_l: source and target nodes come from the
        # same feature matrix, so a single shared projection is used.
        self.lin_r = self.lin_l

        # a_l / a_r: per-head attention vectors, filled in reset_parameters().
        self.att_l = Parameter(torch.empty(1, heads, out_channels))
        self.att_r = Parameter(torch.empty(1, heads, out_channels))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise projection weights and attention vectors (Xavier uniform)."""
        nn.init.xavier_uniform_(self.lin_l.weight)
        # lin_r normally aliases lin_l; only initialise it separately if a
        # caller has replaced it with a distinct module.
        if self.lin_r is not self.lin_l:
            nn.init.xavier_uniform_(self.lin_r.weight)
        nn.init.xavier_uniform_(self.att_l)
        nn.init.xavier_uniform_(self.att_r)

    def forward(self, x, edge_index, size=None):
        """Compute attention-weighted node embeddings.

        Args:
            x: Node feature matrix; assumed to be ``[N, in_channels]``
                (TODO confirm against callers).
            edge_index: Graph connectivity passed straight to ``propagate``.
            size: Optional size hint passed straight to ``propagate``.

        Returns:
            Tensor of shape ``[N, heads * out_channels]`` with the per-head
            outputs concatenated along the last dimension.
        """
        H, C = self.heads, self.out_channels

        # Project and split into heads: [N, H * C] -> [N, H, C].
        x_l = self.lin_l(x).view(-1, H, C)
        x_r = self.lin_r(x).view(-1, H, C)

        # One scalar score per (node, head): reduce over the feature
        # dimension, NOT over the head dimension.  Result: [N, H].
        alpha_l = (x_l * self.att_l).sum(dim=-1)
        alpha_r = (x_r * self.att_r).sum(dim=-1)

        out = self.propagate(edge_index, x=(x_l, x_r),
                             alpha=(alpha_l, alpha_r), size=size)

        # Concatenate heads: [N, H, C] -> [N, H * C].
        return out.view(-1, H * C)

    def message(self, x_j, alpha_j, alpha_i, index, ptr, size_i):
        """Build the attention-weighted message for every edge.

        Args:
            x_j: Per-edge projected features of the ``_j`` endpoint, [E, H, C].
            alpha_j: Per-edge, per-head score term of the ``_j`` endpoint, [E, H].
            alpha_i: Per-edge, per-head score term of the ``_i`` endpoint, [E, H].
            index: Per-edge index used to group edges for the softmax.
            ptr: Optional pointer form of the grouping, forwarded to ``softmax``.
            size_i: Number of groups, forwarded to ``softmax``.

        Returns:
            Tensor [E, H, C]: ``x_j`` scaled by the normalised attention.
        """
        # Raw score, with the configured LeakyReLU negative slope.
        alpha = F.leaky_relu(alpha_i + alpha_j, self.negative_slope)
        # Normalise across all edges that share the same ``index`` entry.
        alpha = softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        # [E, H] -> [E, H, 1] so each head's coefficient scales all C channels.
        return x_j * alpha.unsqueeze(-1)

    def aggregate(self, inputs, index, dim_size=None):
        """Sum the per-edge messages into their target nodes.

        Args:
            inputs: Per-edge messages produced by ``message``.
            index: Per-edge target index to scatter into.
            dim_size: Optional number of output rows along ``self.node_dim``.

        Returns:
            The scatter-summed tensor.  It must be returned: ``propagate``
            hands this value back to ``forward``.
        """
        return torch_scatter.scatter(inputs, index, dim=self.node_dim,
                                     dim_size=dim_size, reduce='sum')