torch.sparse稀疏矩阵

前言

这几天接触图神经网络,会发现有些图的边相对结点比较少,那么邻接矩阵很多元素都为0, 显然,我们可以优化存储策略,以节省内存和加快运算。本文udating…



例子

定义稀疏矩阵

import torch

# 邻接矩阵
adj_dense = torch.tensor([[2, 0], [0, 3]])

# 下标 [0, 0], [1, 1]
indices = torch.tensor([[0, 1],
                        [0, 1]])
values = torch.tensor([2, 3])
shape = torch.Size((2, 2))
# 定义一个稀疏矩阵
adj = torch.sparse.FloatTensor(indices, values, shape)

print(adj)
print(adj.to_dense())

w = torch.tensor([[1,2], [3,4]])

spmm = torch.spmm(adj, w)
mm = torch.mm(adj, w)
print('spmm:', spmm)
print('mm:  ', mm)

tensor(indices=tensor([[0, 1],
[0, 1]]),
values=tensor([2, 3]),
size=(2, 2), nnz=2, layout=torch.sparse_coo)
tensor([[2, 0],
[0, 3]])
spmm: tensor([[ 2, 4],
[ 9, 12]])
mm: tensor([[ 2, 4],
[ 9, 12]])


邻接矩阵转换为稀疏矩阵

邻接矩阵在这里也称为稠密矩阵。

import torch
import scipy.sparse as sp


def _convert_sp_mat_to_sp_tensor(X):
    # scipy稠密矩阵转换为scipy的稀疏矩阵方法为scipy.sparse,scipy的稀疏矩阵转为稠密矩阵的方法,直接.todense()
    # pytorch的稀疏矩阵转为稠密矩阵.to_dense(),稠密矩阵转稀疏矩阵torch.sparse.FloatTensor(i, v, coo.shape)
    # 代码更改自 https://blog.csdn.net/qq_29494693/article/details/121324097
    coo = X.tocoo()
    i = torch.LongTensor([coo.row, coo.col])
    v = torch.from_numpy(coo.data).float()
    return torch.sparse.FloatTensor(i, v, coo.shape)


if __name__ == '__main__':
	adj = [[0, 1, 0], [0, 0, 2], [0, 1, 0]]
    adj_sp = sp.csr_matrix(adj)
    print('adj_sp:\n', adj_sp)
    adj_torch_sp = _convert_sp_mat_to_sp_tensor(adj_sp)
    print('adj_torch_sp:\n', adj_torch_sp)
    print('adj_torch_sp.to_dense():\n', adj_torch_sp.to_dense())

adj_sp:
(0, 1) 1
(1, 2) 2
(2, 1) 1
adj_torch_sp:
tensor(indices=tensor([[0, 1, 2],
[1, 2, 1]]),
values=tensor([1., 2., 1.]),
size=(3, 3), nnz=3, layout=torch.sparse_coo)
adj_torch_sp.to_dense():
tensor([[0., 1., 0.],
[0., 0., 2.],
[0., 1., 0.]])



高维稀疏矩阵

import torch


num_edge_type = 2
n_node = 4
index = torch.tensor([[0, 1], [0, 2], [1, 3]])
value = torch.tensor([1.0, 1.0])
shape = torch.Size([num_edge_type, n_node, n_node])

# 创建稀疏矩阵
sp = torch.sparse.FloatTensor(index, value, shape)
print('sp:\n', sp)
# 转化为稠密矩阵
sp_dense = sp.to_dense()
print('sp_dense:\n', sp_dense)
sp:
 tensor(indices=tensor([[0, 1],
                       [0, 2],
                       [1, 3]]),
       values=tensor([1., 1.]),
       size=(2, 4, 4), nnz=2, layout=torch.sparse_coo)
sp_dense:
 tensor([[[0., 1., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],
        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 1.],
         [0., 0., 0., 0.]]])



resize

在喂入模型的时候, 可能需要padding一下,这个时候之前的稀疏矩阵就需要resize一下了。
torch.sparse.tensor.sparse_resize_函数用法如下:

import torch


num_edge_type = 2
n_node = 4
index = torch.tensor([[0, 1, 1], [0, 0, 2], [1, 1, 3]])
value = torch.tensor([1.0, 1.0, 1.0])
shape = torch.Size([num_edge_type, n_node, n_node])

# 创建稀疏矩阵
sp = torch.sparse.FloatTensor(index, value, shape)
print('sp:\n', sp)
sp_dense = sp.to_dense()
print('sp_dense:\n', sp_dense)


print('\n')
new_nodes_size = 6    # 新的图的节点个数 
sp.sparse_resize_(size=(2, new_nodes_size, new_nodes_size), sparse_dim=3, dense_dim=0)
print('sp:\n', sp)
sp_dense = sp.to_dense()
print('sp_dense:\n', sp_dense)
sp:
 tensor(indices=tensor([[0, 1, 1], 
                       [0, 0, 2],  
                       [1, 1, 3]]),
       values=tensor([1., 1., 1.]),
       size=(2, 4, 4), nnz=3, layout=torch.sparse_coo)
sp_dense:
 tensor([[[0., 1., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 1., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 1.],
         [0., 0., 0., 0.]]])


sp:
 tensor(indices=tensor([[0, 1, 1],
                       [0, 0, 2],
                       [1, 1, 3]]),
       values=tensor([1., 1., 1.]),
       size=(2, 6, 6), nnz=3, layout=torch.sparse_coo)
sp_dense:
 tensor([[[0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]],

        [[0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]])

你可能感兴趣的:(编程语言学习,矩阵,pytorch,深度学习,图神经网络)