RMPI-graph_classifier.py

class GraphClassifier(nn.Module):
    def __init__(self, params, relation2id, rel_vectors):  # in_dim, h_dim, rel_emb_dim, out_dim, num_rels, num_bases):
        super().__init__()

        self.params = params
        self.relation2id = relation2id
        self.relation_list = list(self.relation2id.values())
        self.link_mode = 6
        self.is_big_dataset = False
        self.is_big_dataset = True if self.params.dataset in ['wikidata_small'] else False

        ############################
        self.rel_vectors = rel_vectors
        self.no_jk = self.params.no_jk
        self.target2nei_atten = params.target2nei_atten
        #############################


        self.rel_emb = nn.Embedding(self.params.num_rels, self.params.rel_emb_dim, sparse=False)

        torch.nn.init.normal_(self.rel_emb.weight)

        self.fc_reld1 = nn.ModuleList([nn.Linear(self.params.rel_emb_dim, self.params.rel_emb_dim, bias=True)
                                      for _ in range(6)
                                      ])
        self.fc_reld2 = nn.ModuleList([nn.Linear(self.params.rel_emb_dim, self.params.rel_emb_dim, bias=True)
                                      for _ in range(6)
                                      ])
        self.fc_reld = nn.Linear(self.params.rel_emb_dim, self.params.rel_emb_dim, bias=True)


        self.fc_layer = nn.Linear(self.params.rel_emb_dim, 1)

        if self.params.conc:
            self.conc = nn.Linear(self.params.rel_emb_dim*2, self.params.rel_emb_dim)


        if self.params.gpu >= 0:
            self.device = torch.device('cuda:%d' % self.params.gpu)
            self.rel_vectors = self.rel_vectors.to(device=self.device)
        else:
            self.device = torch.device('cpu')

        #######################
        self.transform1 = nn.Linear(self.rel_vectors.shape[1], self.params.rel_emb_dim)
        self.transform2 = nn.Linear(self.params.rel_emb_dim, self.params.rel_emb_dim)
        #######################

        self.leakyrelu = nn.LeakyReLU(0.2)
        self.drop = torch.nn.Dropout(0.5)
   def rel_aggr(self, graph, u_node, v_node, num_nodes, num_edges, aggr_flag, is_drop):
        u_in_edge = graph.in_edges(u_node, 'all')
        u_out_edge = graph.out_edges(u_node, 'all')
        v_in_edge = graph.in_edges(v_node, 'all')
        v_out_edge = graph.out_edges(v_node, 'all')

        edge_mask = self.drop(torch.ones(num_edges))
        edge_mask = edge_mask.repeat(num_nodes, 1)

        in_edge_out = torch.sparse_coo_tensor(torch.cat((u_in_edge[1].unsqueeze(0), u_in_edge[2].unsqueeze(0)), 0),
                                              torch.ones(len(u_in_edge[2])), size=torch.Size((num_nodes, num_edges)))
        out_edge_out = torch.sparse_coo_tensor(torch.cat((u_out_edge[0].unsqueeze(0), u_out_edge[2].unsqueeze(0)), 0),
                                               torch.ones(len(u_out_edge[2])), size=torch.Size((num_nodes, num_edges)))
        in_edge_in = torch.sparse_coo_tensor(torch.cat((v_in_edge[1].unsqueeze(0), v_in_edge[2].unsqueeze(0)), 0),
                                             torch.ones(len(v_in_edge[2])), size=torch.Size((num_nodes, num_edges)))
        out_edge_in = torch.sparse_coo_tensor(torch.cat((v_out_edge[0].unsqueeze(0), v_out_edge[2].unsqueeze(0)), 0),
                                              torch.ones(len(v_out_edge[2])), size=torch.Size((num_nodes, num_edges)))

        if is_drop:
            in_edge_out = self.sparse_dense_mul(in_edge_out, edge_mask)
            out_edge_out = self.sparse_dense_mul(out_edge_out, edge_mask)
            in_edge_in = self.sparse_dense_mul(in_edge_in, edge_mask)
            out_edge_in = self.sparse_dense_mul(out_edge_in, edge_mask)

        if self.is_big_dataset:  # smaller memory
            in_edge_out = self.sparse_index_select(in_edge_out, u_node).to(device=self.device)
            out_edge_out = self.sparse_index_select(out_edge_out, u_node).to(device=self.device)
            in_edge_in = self.sparse_index_select(in_edge_in, v_node).to(device=self.device)
            out_edge_in = self.sparse_index_select(out_edge_in, v_node).to(device=self.device)
        else:  # faster calculation
            in_edge_out = in_edge_out.to(device=self.device).to_dense()[u_node].to_sparse()
            out_edge_out = out_edge_out.to(device=self.device).to_dense()[u_node].to_sparse()
            in_edge_in = in_edge_in.to(device=self.device).to_dense()[v_node].to_sparse()
            out_edge_in = out_edge_in.to(device=self.device).to_dense()[v_node].to_sparse()


        edge_mode_5 = out_edge_out.mul(in_edge_in)
        edge_mode_6 = in_edge_out.mul(out_edge_in)
        out_edge_out = out_edge_out.sub(edge_mode_5)
        in_edge_in = in_edge_in.sub(edge_mode_5)
        in_edge_out = in_edge_out.sub(edge_mode_6)
        out_edge_in = out_edge_in.sub(edge_mode_6)   

这段代码主要用于图数据的处理和计算。让我逐步解释它的功能:

  1. 首先,从输入的图 graph 中获取了四个稀疏张量 u_in_edgeu_out_edgev_in_edgev_out_edge,这些张量分别表示了节点 u_node 和节点 v_node 的入边和出边。

  2. 接下来,创建了一个掩码矩阵 edge_mask,这个矩阵的目的是用来控制是否保留某些边。掩码矩阵的形状是 (num_nodes, num_edges),并初始化为全1。然后,使用 self.drop 函数对掩码进行了一些操作,以确定哪些边需要保留。

  3. 接着,创建了四个稀疏张量 in_edge_outout_edge_outin_edge_inout_edge_in。这些张量用于表示不同类型的边,并且将它们初始化为稀疏张量,形状也是 (num_nodes, num_edges)。这些张量的值是由之前从图中提取的边信息构建而成的。

  4. 如果 is_drop 为真,则对这四个稀疏张量应用了掩码操作,将不需要的边设置为零。

  5. 根据条件 self.is_big_dataset,选择了不同的方式来处理稀疏张量。如果条件为真,使用 self.sparse_index_select 函数选择了特定节点(u_nodev_node)的相关边,并将其移到指定的设备上(self.device)。如果条件为假,执行了另一种方式的处理,首先将稀疏张量转换为密集张量,选择特定节点的边,然后再将其转换回稀疏张量。

  6. 最后,进行了一系列矩阵操作。edge_mode_5edge_mode_6 是通过对不同的稀疏张量进行逐元素相乘操作得到的。然后,这些操作结果被用于修改 out_edge_outin_edge_inin_edge_outout_edge_in 的值,通过减去对应的 edge_mode 来更新这些稀疏张量的值。

总的来说,这段代码主要用于处理图数据,包括边的掩码、稀疏张量的选择和矩阵操作。这些操作可能与图的结构和计算有关,但具体的目的和效果需要根据代码的上下文和输入参数来理解。

        if aggr_flag == 1:
            edge_connect_l = [in_edge_out, out_edge_out, in_edge_in, out_edge_in, edge_mode_5, edge_mode_6]

            rel_neighbor_embd = sum([torch.sparse.mm(edge_connect_l[i],
                                                     self.fc_reld2[i](self.h1)) for i in range(self.link_mode)])


            return rel_neighbor_embd

        elif aggr_flag == 2:
            edge_connect_l = [in_edge_out, out_edge_out, in_edge_in, out_edge_in, edge_mode_5, edge_mode_6]

            rel_neighbor_embd = sum([torch.sparse.mm(edge_connect_l[i],
                                                     self.fc_reld1[i](self.h0)) for i in
                                     range(self.link_mode)])

            return rel_neighbor_embd

        elif aggr_flag == 0:
            num_target = u_node.shape[0]
            dis_target_edge_ids = self.rel_edge_ids
            self_mask = torch.ones((num_target, num_edges))
            for i in range(num_target):
                self_mask[i][dis_target_edge_ids[i]] = 0
            self_mask = self_mask.to(device=self.device)
            edge_mode_5 = self.sparse_dense_mul(edge_mode_5, self_mask)
            # self_mask = torch.sparse_coo_tensor((rows, dis_target_edge_ids), values, size=torch.Size((num_target, num_edges)))
            # print(edge_mode_5)



            # edge_connect_l = sum(edge_connect_l)
            edge_connect_l = in_edge_out + out_edge_out + in_edge_in + out_edge_in + edge_mode_5 + edge_mode_6


            # neighbor_rel_embeds = self.transform(self.rel_vectors[graph.edata['type']])
            neighbor_rel_embeds = self.transform2(self.transform1(self.rel_vectors[graph.edata['type']]))
            ##########################
            rel_embeds = self.transform2(self.transform1(self.rel_vectors[self.rel_labels]))
            ##########################

            rel_2directed_atten = torch.einsum('bd,nd->bn', [self.fc_reld(rel_embeds), self.fc_reld(neighbor_rel_embeds)])
            rel_2directed_atten = self.leakyrelu(rel_2directed_atten)

            atten = self.sparse_dense_mul(edge_connect_l, rel_2directed_atten).to_dense()
            mask = (atten == 0).bool()
            atten_softmax = torch.nn.Softmax(dim=-1)(atten.masked_fill(mask, -np.inf))
            atten_softmax = torch.where(torch.isnan(atten_softmax), torch.full_like(atten_softmax, 0),
                                            atten_softmax).to_sparse()
            # rel_neighbor_embd = torch.sparse.mm(atten_softmax, neighbor_rel_embeds)
            rel_neighbor_embd = torch.sparse.mm(atten_softmax, self.fc_reld(neighbor_rel_embeds))


            return rel_neighbor_embd

这段代码似乎是根据不同的aggr_flag(聚合标志)执行不同的图数据聚合操作。我将逐步解释每个分支的功能:

  1. 如果aggr_flag等于1,代码执行以下操作

    • 创建一个包含六个稀疏张量的列表 edge_connect_l,这些张量包括之前计算的各种边和模式。
    • 使用列表推导式对 edge_connect_l 中的每个张量执行矩阵乘法操作,将其与一个神经网络层 self.fc_reld2[i](self.h1) 相乘,并将结果相加,得到 rel_neighbor_embd
    • 返回 rel_neighbor_embd 作为结果。
  2. aggr_flag等于2,代码执行以下操作:

    • 创建一个包含六个稀疏张量的列表 edge_connect_l,同样包括之前计算的各种边和模式。
    • 使用列表推导式对 edge_connect_l 中的每个张量执行矩阵乘法操作,将其与一个神经网络层 self.fc_reld1[i](self.h0) 相乘,并将结果相加,得到 rel_neighbor_embd
    • 返回 rel_neighbor_embd 作为结果。
  3. 如果aggr_flag等于0,代码执行以下操作:

    • 计算 num_target,它代表目标节点的数量,并获取名为 dis_target_edge_ids 的与目标节点相关的边的ID。
    • 创建一个名为 self_mask 的矩阵,形状为 (num_target, num_edges),并初始化为全1。
    • 针对每个目标节点,将 self_mask 中与其相关的边的值设置为0,这似乎是用于掩盖目标节点的自连接边。
    • self_mask 转移到指定设备。
    • 使用 self.sparse_dense_mul 函数将 edge_mode_5self_mask 相乘,以消除自连接边。
    • edge_connect_l 执行一系列操作,包括将所有稀疏张量相加,并修改了其中的某些张量。
    • 对节点嵌入进行一些神经网络操作,并计算了一些关于节点之间关系的注意力权重。
    • 最后,根据注意力权重对邻居节点的嵌入进行加权求和,得到 rel_neighbor_embd

请注意,这段代码涉及了许多不同的矩阵操作、神经网络层的应用以及注意力机制的计算。它的具体功能和效果可能需要根据上下文和模型的结构来理解。

你可能感兴趣的:(杂,python,深度学习,pytorch)