以下的代码表示,构造的矩阵在(0,0),(1,1),(2,2)
这几个位置都有非0取值
import torch
i = torch.tensor([[0, 1, 2],[0, 1, 2]])
以下的代码表示,矩阵三个位置存在值,分别取1, 3, 9
v = torch.tensor([1, 3, 9])
以下的代码表示,构造出一个形似:
(0, 0) 1
(1, 1) 3
(2, 2) 9
的稀疏矩阵,这个矩阵的shape为(3, 3)
>>> torch.sparse.FloatTensor(i, v, (3, 3))
tensor(indices=tensor([[0, 1, 2],
[0, 1, 2]]),
values=tensor([1, 3, 9]),
size=(3, 3), nnz=3, layout=torch.sparse_coo)
将稀疏矩阵转成稠密矩阵,可以更加清晰地观察到矩阵的一个形态
>>> sparse_matrix.to_dense()
tensor([[1, 0, 0],
[0, 3, 0],
[0, 0, 9]])