目录
1. 重置label取值范围
2. 根据多个labels tensor从整体label数据中提取特定数据。
3. 构建Geometric GATConv和GCNConv的 edge_index
4. 为GCNConv从全部edge index抽取指定的batch edge index
4.1 extract batch edge index for full nodes
4.2 extract batch edge index for batch nodes, reset in the range of [0, batch_size].
5. 将二维 edge index tensor(shape=[2, n])转换为 tensor matrix(邻接矩阵)格式
6. 归一化tensor matrix
7. 快速计算 tensor 中每两个元素之差的绝对值,输出结果为 tensor matrix
8. 将偏差矩阵转换为邻接矩阵,bias_mx -> adj_mx
problem: label 的取值必须重置到连续的范围 [0, 类别数) 内,否则会出现 IndexError: target out of bounds
# Reset the label values onto the contiguous range [0, number of unique labels);
# otherwise "IndexError: target out of bounds" occurs.
#
# `return_inverse=True` gives, for every element of `labels`, the position of its
# value among the sorted unique values. Remapping in one step avoids the value
# collisions a sequential replace can hit (e.g. labels [-1, 0]: -1 -> 0 first,
# and then every 0 -> 1), and needs a single pass instead of one per unique value.
uni_set, labels_reset = torch.unique(labels, return_inverse=True)
# The new label ids, kept for reference: uni_set[k] was remapped to to_set[k].
to_set = torch.arange(len(uni_set))
# Boolean mask of the samples whose label equals `label`,
# e.g. shape (100,), [True, False, True, True, ...].
label_mask = (labels == label)
# Indices of the samples carrying that label, e.g. array([0, 2, 3]).
label_indices = np.where(label_mask)[0]
# Indices of all remaining samples; these serve as negative samples.
negative_indices = np.where(np.logical_not(label_mask))[0]
# anchor_pos_list = list(combinations(label_indices, 2))  # all index pairs, e.g. [(23, 66), (23, 79), (66, 79)]
# Pick the entries of `edge_index_mx` that belong to the selected label.
# NOTE(review): the original read `edge_index_mx[0: label_indices]`, which uses an
# index array as a slice bound and raises. Column selection matches the
# `edge_index[0:, idx]` pattern used elsewhere in these notes -- confirm that
# columns (not rows) are what is wanted here.
extract_index_data = edge_index_mx[:, label_indices]
torch geometric(即 PyG)的 edge_index 数据是二维 tensor,shape=[2, n],因此需要先把稀疏关系矩阵转换成这种格式。
# relations_ids = ['entity', 'userid', 'word'],分别读取这三个文件
def sparse_trans(datapath=None):
    """Convert a saved SciPy sparse relation matrix into a 2 x n edge index.

    For every row ``node`` of the matrix the result contains one column
    ``[node, neighbor]`` per non-zero off-diagonal entry (in ascending
    neighbor order), followed by exactly one self-loop column ``[node, node]``.

    Args:
        datapath: Path of the ``.npz`` file written by ``scipy.sparse.save_npz``.
            Required in practice; the ``None`` default is kept only for
            backward compatibility.

    Returns:
        A ``torch.long`` tensor of shape ``(2, n)``; ``(2, 0)`` when the matrix
        has no rows.
    """
    # Row indexing below needs CSR; other saved formats (e.g. COO) are converted.
    relation = sparse.load_npz(datapath).tocsr()

    # Collect the per-node pieces and concatenate once at the end; growing the
    # result with torch.cat inside the loop copies everything on every step.
    edge_chunks = []
    for node in range(relation.shape[0]):
        # Dense copy of this row, truncated to int32 as in the original code.
        # reshape(-1) (rather than squeeze) keeps it 1-D even for a 1x1 matrix.
        neighbor = torch.IntTensor(relation[node].toarray()).reshape(-1)
        # Drop the existing diagonal entry; a single self-loop is added below.
        neighbor[node] = 0
        neighbor_idx = neighbor.nonzero().reshape(-1)
        source = torch.full_like(neighbor_idx, node)
        edge_chunks.append(torch.stack((source, neighbor_idx)))  # node -> neighbor
        edge_chunks.append(torch.tensor([[node], [node]]))  # self-loop

    if not edge_chunks:
        return torch.empty((2, 0), dtype=torch.long)
    return torch.cat(edge_chunks, dim=1)
因为 GCNConv 需要执行卷积操作(convolution),如果 edge index 中的节点索引超出了 batch 的大小,就会报错!所以需要从全部 edge index 中抽取当前 batch 对应的部分。
def extract_batch_edge_idx(batch_nodes, edge_index):
    """Keep the edges whose source node belongs to the batch.

    Node ids are left unchanged, so the result still indexes into the full
    node set.

    Args:
        batch_nodes: Node ids of the batch (list, array or tensor).
        edge_index: Tensor of shape ``(2, n)``; row 0 holds the source nodes.

    Returns:
        A ``torch.long`` tensor of shape ``(2, k)`` with the kept columns in
        their original order.
    """
    # Compare in the dtype of the edge index. The previous torch.Tensor(...)
    # conversion cast the ids to float32, which is inexact for large ids.
    batch_nodes = torch.as_tensor(
        batch_nodes, dtype=edge_index.dtype, device=edge_index.device
    )
    # Vectorised membership test instead of a Python loop over every edge.
    source_in_batch = torch.isin(edge_index[0], batch_nodes)
    return edge_index[:, source_in_batch].type(torch.long)
# extract batch edge index for batch nodes
def extract_batch_edge_idx(self, batch_nodes, edge_index):
    """Keep the edges with both endpoints in the batch and renumber the nodes.

    The surviving node ids are replaced by their rank among the sorted unique
    ids that still occur in the kept edges, i.e. values in
    ``[0, number of distinct remaining nodes)``.

    NOTE(review): a batch node without any kept edge receives no index, so the
    new ids are not necessarily positions in ``batch_nodes`` -- confirm this is
    what the caller expects.

    Args:
        batch_nodes: Node ids of the batch (list, array or tensor).
        edge_index: Tensor of shape ``(2, n)``; row 0 = source, row 1 = target.

    Returns:
        A ``torch.long`` tensor of shape ``(2, k)`` with renumbered node ids,
        columns in their original order.
    """
    batch_ids = torch.as_tensor(
        batch_nodes, dtype=edge_index.dtype, device=edge_index.device
    )
    # One vectorised membership test per row replaces two Python loops.
    keep = torch.isin(edge_index[0], batch_ids) & torch.isin(edge_index[1], batch_ids)
    batch_edge_index = edge_index[:, keep]
    # The inverse indices of unique() are exactly the rank-based renumbering.
    _, relabelled = torch.unique(batch_edge_index, return_inverse=True)
    return relabelled.type(torch.long)
def extract_batch_edge_idx(batch_nodes, edge_index):
    """Collect the edges between batch nodes, grouped by batch order, and renumber.

    Edges are gathered for every ordered pair ``(i, j)`` of ``batch_nodes`` in
    iteration order, so the output columns follow the order of ``batch_nodes``
    (first by source, then by target) rather than the order in ``edge_index``.
    Node ids are then replaced by their rank among the sorted unique ids that
    occur in the collected edges.

    Args:
        batch_nodes: Iterable of node ids in the batch.
        edge_index: Tensor of shape ``(2, n)``; row 0 = source, row 1 = target.

    Returns:
        A ``torch.long`` tensor of shape ``(2, k)`` with renumbered node ids;
        ``(2, 0)`` when ``batch_nodes`` is empty.
    """
    # Gather the pieces in a list and concatenate once. The previous float
    # `torch.Tensor()` accumulators mixed dtypes with the integer edge index,
    # and the `is None` checks on comparison results could never trigger.
    chunks = []
    for i in batch_nodes:
        # Boolean masks keep everything in torch (no numpy round-trip).
        from_i = edge_index[:, edge_index[0] == i]
        for j in batch_nodes:
            chunks.append(from_i[:, from_i[1] == j])

    if not chunks:
        return torch.empty((2, 0), dtype=torch.long, device=edge_index.device)

    extract_edge_index = torch.cat(chunks, dim=1)
    # Reset the index values into a contiguous range: the inverse indices of
    # unique() are the rank of each id among the sorted unique ids.
    _, relabelled = torch.unique(extract_edge_index, return_inverse=True)
    return relabelled.type(torch.long)
# # numpy version: 将二维矩阵list 转换成adj matrix list
# def relations_to_adj(filtered_multi_r_data, nb_nodes=None):
# relations_mx_list = []
# for r_data in filtered_multi_r_data:
# data = np.ones(r_data.shape[1])
# relation_mx = sp.coo_matrix((data, (r_data[0], r_data[1])), shape=(nb_nodes, nb_nodes), dtype=int)
# relations_mx_list.append(torch.tensor(relation_mx.todense()))
# return relations_mx_list
# # tensor version: 将二维矩阵list 转换成adj matrix list
def relations_to_adj(filtered_multi_r_data, nb_nodes=None):
    """Turn a list of 2 x n edge indices into dense adjacency matrices.

    Each edge ``(row, col)`` contributes 1 to the corresponding matrix entry;
    repeated edges add up.

    Args:
        filtered_multi_r_data: Iterable of edge indices, each of shape
            ``(2, n)`` (tensor or array-like).
        nb_nodes: Side length of every adjacency matrix. When ``None`` it is
            inferred as the largest node id over all relations plus one.

    Returns:
        A list with one dense ``torch.int32`` tensor of shape
        ``(nb_nodes, nb_nodes)`` per relation.
    """
    relations = [torch.as_tensor(r_data) for r_data in filtered_multi_r_data]
    if nb_nodes is None:
        nb_nodes = max(
            (int(r.max()) + 1 for r in relations if r.numel() > 0), default=0
        )

    relations_mx_list = []
    for r_data in relations:
        # Create the values on the same device as the indices and directly in
        # the target dtype, so GPU edge indices do not hit a device mismatch.
        data = torch.ones(r_data.shape[1], dtype=torch.int32, device=r_data.device)
        relation_mx = torch.sparse_coo_tensor(
            indices=r_data, values=data, size=[nb_nodes, nb_nodes], dtype=torch.int32
        )
        relations_mx_list.append(relation_mx.to_dense())
    return relations_mx_list
# # numpy version: 归一化矩阵
# def normalize_adj(adj): # tensor, (4286,4286)
# """Symmetrically normalize adjacency matrix."""
# adj = np.array(adj.cpu())
# mu = np.mean(adj, axis=1)
# std = np.std(adj, axis=1)
# return torch.from_numpy((adj-mu)/std)
# tensor version: 归一化矩阵
def normalize_adj(adj):  # tensor, e.g. (4286, 4286)
    """Standardise every column of a matrix to zero mean and unit variance.

    This is a column-wise z-score (mean and sample standard deviation taken
    along ``dim=0``), not the symmetric D^-1/2 A D^-1/2 normalisation.

    Args:
        adj: 2-D tensor. Integer tensors are converted to the default floating
            dtype first, because ``torch.mean`` and ``torch.std`` only accept
            floating-point input.

    Returns:
        A floating-point tensor of the same shape. Columns whose standard
        deviation is zero or undefined are only mean-centred (divided by 1),
        so no NaN or inf is introduced by the division.
    """
    if not torch.is_floating_point(adj):
        adj = adj.to(torch.get_default_dtype())
    mu = torch.mean(adj, dim=0)
    std = torch.std(adj, dim=0)
    # Guard constant columns against division by zero; `std > 0` is also False
    # for NaN, which torch.std returns when there is only a single row.
    safe_std = torch.where(std > 0, std, torch.ones_like(std))
    norm_adj = torch.sub(adj, mu) / safe_std
    return norm_adj
# Transpose and drop the size-1 dimensions, e.g. (1, 100) -> (100,).
vectors = vectors.t().squeeze()
# Absolute difference of every pair of elements, laid out as a matrix:
# entry [i, j] is |vectors[j] - vectors[i]|.
diff_matrix = (vectors[None, :] - vectors[:, None]).abs()
import torch
def bias_to_adj(bias):
    """Rebuild a 0/1 adjacency matrix from an attention-bias matrix.

    Entries that are negative (e.g. the -1e9 mask value) become 0; every other
    entry becomes 1.

    Args:
        bias: Tensor (or array-like convertible with ``torch.as_tensor``).

    Returns:
        A ``torch.int64`` tensor of the same shape containing only 0 and 1.
    """
    bias = torch.as_tensor(bias)
    # Casting the comparison avoids the scalar-only form of torch.where
    # and keeps the result dtype explicit.
    adj = torch.logical_not(bias < 0).to(torch.int64)
    return adj
# Example usage: every entry carries the -1e9 mask value except the
# bottom-right one.
bias = torch.full((3, 3), -1e9)
bias[2, 2] = 0.0
adj = bias_to_adj(bias)
print("邻接矩阵:")
print(adj)