class GraphClassifier(nn.Module):
def __init__(self, params, relation2id, rel_vectors): # in_dim, h_dim, rel_emb_dim, out_dim, num_rels, num_bases):
super().__init__()
self.params = params
self.relation2id = relation2id
self.relation_list = list(self.relation2id.values())
self.link_mode = 6
self.is_big_dataset = False
self.is_big_dataset = True if self.params.dataset in ['wikidata_small'] else False
############################
self.rel_vectors = rel_vectors
self.no_jk = self.params.no_jk
self.target2nei_atten = params.target2nei_atten
#############################
self.rel_emb = nn.Embedding(self.params.num_rels, self.params.rel_emb_dim, sparse=False)
torch.nn.init.normal_(self.rel_emb.weight)
self.fc_reld1 = nn.ModuleList([nn.Linear(self.params.rel_emb_dim, self.params.rel_emb_dim, bias=True)
for _ in range(6)
])
self.fc_reld2 = nn.ModuleList([nn.Linear(self.params.rel_emb_dim, self.params.rel_emb_dim, bias=True)
for _ in range(6)
])
self.fc_reld = nn.Linear(self.params.rel_emb_dim, self.params.rel_emb_dim, bias=True)
self.fc_layer = nn.Linear(self.params.rel_emb_dim, 1)
if self.params.conc:
self.conc = nn.Linear(self.params.rel_emb_dim*2, self.params.rel_emb_dim)
if self.params.gpu >= 0:
self.device = torch.device('cuda:%d' % self.params.gpu)
self.rel_vectors = self.rel_vectors.to(device=self.device)
else:
self.device = torch.device('cpu')
#######################
self.transform1 = nn.Linear(self.rel_vectors.shape[1], self.params.rel_emb_dim)
self.transform2 = nn.Linear(self.params.rel_emb_dim, self.params.rel_emb_dim)
#######################
self.leakyrelu = nn.LeakyReLU(0.2)
self.drop = torch.nn.Dropout(0.5)
def rel_aggr(self, graph, u_node, v_node, num_nodes, num_edges, aggr_flag, is_drop):
u_in_edge = graph.in_edges(u_node, 'all')
u_out_edge = graph.out_edges(u_node, 'all')
v_in_edge = graph.in_edges(v_node, 'all')
v_out_edge = graph.out_edges(v_node, 'all')
edge_mask = self.drop(torch.ones(num_edges))
edge_mask = edge_mask.repeat(num_nodes, 1)
in_edge_out = torch.sparse_coo_tensor(torch.cat((u_in_edge[1].unsqueeze(0), u_in_edge[2].unsqueeze(0)), 0),
torch.ones(len(u_in_edge[2])), size=torch.Size((num_nodes, num_edges)))
out_edge_out = torch.sparse_coo_tensor(torch.cat((u_out_edge[0].unsqueeze(0), u_out_edge[2].unsqueeze(0)), 0),
torch.ones(len(u_out_edge[2])), size=torch.Size((num_nodes, num_edges)))
in_edge_in = torch.sparse_coo_tensor(torch.cat((v_in_edge[1].unsqueeze(0), v_in_edge[2].unsqueeze(0)), 0),
torch.ones(len(v_in_edge[2])), size=torch.Size((num_nodes, num_edges)))
out_edge_in = torch.sparse_coo_tensor(torch.cat((v_out_edge[0].unsqueeze(0), v_out_edge[2].unsqueeze(0)), 0),
torch.ones(len(v_out_edge[2])), size=torch.Size((num_nodes, num_edges)))
if is_drop:
in_edge_out = self.sparse_dense_mul(in_edge_out, edge_mask)
out_edge_out = self.sparse_dense_mul(out_edge_out, edge_mask)
in_edge_in = self.sparse_dense_mul(in_edge_in, edge_mask)
out_edge_in = self.sparse_dense_mul(out_edge_in, edge_mask)
if self.is_big_dataset: # smaller memory
in_edge_out = self.sparse_index_select(in_edge_out, u_node).to(device=self.device)
out_edge_out = self.sparse_index_select(out_edge_out, u_node).to(device=self.device)
in_edge_in = self.sparse_index_select(in_edge_in, v_node).to(device=self.device)
out_edge_in = self.sparse_index_select(out_edge_in, v_node).to(device=self.device)
else: # faster calculation
in_edge_out = in_edge_out.to(device=self.device).to_dense()[u_node].to_sparse()
out_edge_out = out_edge_out.to(device=self.device).to_dense()[u_node].to_sparse()
in_edge_in = in_edge_in.to(device=self.device).to_dense()[v_node].to_sparse()
out_edge_in = out_edge_in.to(device=self.device).to_dense()[v_node].to_sparse()
edge_mode_5 = out_edge_out.mul(in_edge_in)
edge_mode_6 = in_edge_out.mul(out_edge_in)
out_edge_out = out_edge_out.sub(edge_mode_5)
in_edge_in = in_edge_in.sub(edge_mode_5)
in_edge_out = in_edge_out.sub(edge_mode_6)
out_edge_in = out_edge_in.sub(edge_mode_6)
这段代码主要用于图数据的处理和计算。让我逐步解释它的功能:
首先,从输入的图 graph
中获取了四个稀疏张量 u_in_edge
、u_out_edge
、v_in_edge
和 v_out_edge
,这些张量分别表示了节点 u_node
和节点 v_node
的入边和出边。
接下来,创建了一个掩码矩阵 edge_mask
,这个矩阵的目的是用来控制是否保留某些边。掩码矩阵的形状是 (num_nodes, num_edges)
,并初始化为全1。然后,使用 self.drop
函数对掩码进行了一些操作,以确定哪些边需要保留。
接着,创建了四个稀疏张量 in_edge_out
、out_edge_out
、in_edge_in
和 out_edge_in
。这些张量用于表示不同类型的边,并且将它们初始化为稀疏张量,形状也是 (num_nodes, num_edges)
。这些张量的值是由之前从图中提取的边信息构建而成的。
如果 is_drop
为真,则对这四个稀疏张量应用了掩码操作,将不需要的边设置为零。
根据条件 self.is_big_dataset
,选择了不同的方式来处理稀疏张量。如果条件为真,使用 self.sparse_index_select
函数选择了特定节点(u_node
和 v_node
)的相关边,并将其移到指定的设备上(self.device
)。如果条件为假,执行了另一种方式的处理,首先将稀疏张量转换为密集张量,选择特定节点的边,然后再将其转换回稀疏张量。
最后,进行了一系列矩阵操作。edge_mode_5
和 edge_mode_6
是通过对不同的稀疏张量进行逐元素相乘操作得到的。然后,这些操作结果被用于修改 out_edge_out
、in_edge_in
、in_edge_out
和 out_edge_in
的值,通过减去对应的 edge_mode
来更新这些稀疏张量的值。
总的来说,这段代码主要用于处理图数据,包括边的掩码、稀疏张量的选择和矩阵操作。这些操作可能与图的结构和计算有关,但具体的目的和效果需要根据代码的上下文和输入参数来理解。
if aggr_flag == 1:
edge_connect_l = [in_edge_out, out_edge_out, in_edge_in, out_edge_in, edge_mode_5, edge_mode_6]
rel_neighbor_embd = sum([torch.sparse.mm(edge_connect_l[i],
self.fc_reld2[i](self.h1)) for i in range(self.link_mode)])
return rel_neighbor_embd
elif aggr_flag == 2:
edge_connect_l = [in_edge_out, out_edge_out, in_edge_in, out_edge_in, edge_mode_5, edge_mode_6]
rel_neighbor_embd = sum([torch.sparse.mm(edge_connect_l[i],
self.fc_reld1[i](self.h0)) for i in
range(self.link_mode)])
return rel_neighbor_embd
elif aggr_flag == 0:
num_target = u_node.shape[0]
dis_target_edge_ids = self.rel_edge_ids
self_mask = torch.ones((num_target, num_edges))
for i in range(num_target):
self_mask[i][dis_target_edge_ids[i]] = 0
self_mask = self_mask.to(device=self.device)
edge_mode_5 = self.sparse_dense_mul(edge_mode_5, self_mask)
# self_mask = torch.sparse_coo_tensor((rows, dis_target_edge_ids), values, size=torch.Size((num_target, num_edges)))
# print(edge_mode_5)
# edge_connect_l = sum(edge_connect_l)
edge_connect_l = in_edge_out + out_edge_out + in_edge_in + out_edge_in + edge_mode_5 + edge_mode_6
# neighbor_rel_embeds = self.transform(self.rel_vectors[graph.edata['type']])
neighbor_rel_embeds = self.transform2(self.transform1(self.rel_vectors[graph.edata['type']]))
##########################
rel_embeds = self.transform2(self.transform1(self.rel_vectors[self.rel_labels]))
##########################
rel_2directed_atten = torch.einsum('bd,nd->bn', [self.fc_reld(rel_embeds), self.fc_reld(neighbor_rel_embeds)])
rel_2directed_atten = self.leakyrelu(rel_2directed_atten)
atten = self.sparse_dense_mul(edge_connect_l, rel_2directed_atten).to_dense()
mask = (atten == 0).bool()
atten_softmax = torch.nn.Softmax(dim=-1)(atten.masked_fill(mask, -np.inf))
atten_softmax = torch.where(torch.isnan(atten_softmax), torch.full_like(atten_softmax, 0),
atten_softmax).to_sparse()
# rel_neighbor_embd = torch.sparse.mm(atten_softmax, neighbor_rel_embeds)
rel_neighbor_embd = torch.sparse.mm(atten_softmax, self.fc_reld(neighbor_rel_embeds))
return rel_neighbor_embd
这段代码似乎是根据不同的aggr_flag
(聚合标志)执行不同的图数据聚合操作。我将逐步解释每个分支的功能:
如果aggr_flag
等于1,代码执行以下操作:
edge_connect_l
,这些张量包括之前计算的各种边和模式。edge_connect_l
中的每个张量执行矩阵乘法操作,将其与一个神经网络层 self.fc_reld2[i](self.h1)
相乘,并将结果相加,得到 rel_neighbor_embd
。rel_neighbor_embd
作为结果。如果aggr_flag
等于2,代码执行以下操作:
edge_connect_l
,同样包括之前计算的各种边和模式。edge_connect_l
中的每个张量执行矩阵乘法操作,将其与一个神经网络层 self.fc_reld1[i](self.h0)
相乘,并将结果相加,得到 rel_neighbor_embd
。rel_neighbor_embd
作为结果。如果aggr_flag
等于0,代码执行以下操作:
num_target
,它代表目标节点的数量,并获取名为 dis_target_edge_ids
的与目标节点相关的边的ID。self_mask
的矩阵,形状为 (num_target, num_edges)
,并初始化为全1。self_mask
中与其相关的边的值设置为0,这似乎是用于掩盖目标节点的自连接边。self_mask
转移到指定设备。self.sparse_dense_mul
函数将 edge_mode_5
与 self_mask
相乘,以消除自连接边。edge_connect_l
执行一系列操作,包括将所有稀疏张量相加,并修改了其中的某些张量。rel_neighbor_embd
。请注意,这段代码涉及了许多不同的矩阵操作、神经网络层的应用以及注意力机制的计算。它的具体功能和效果可能需要根据上下文和模型的结构来理解。