找出tensor中非零的元素的索引。返回一个包含输入 input
中非零元素索引的张量.输出张量中的每行包含 input
中非零元素的索引。
torch.nonzero(input, *, out=None, as_tuple=False)
参数 | 含义 |
input | 输入的必须是tensor |
out | 输出z×n,n代表输入数据的维度,z是总共非0元素的个数 |
as_tuple |
|
栗子:
import torch
label = torch.tensor([[1,0,0],
[1,0,1]])
print(label.nonzero())
输出:
tensor([[0, 0],
[1, 0],
[1, 2]])
返回的结果就是非零元素的索引,其中[0,0]对应了第一行第一列的1,[1,0]对应了第二行第一列的1,[1,2]对应了第二行第三列的1。
延伸:
import torch
label = torch.tensor([[1,0,0],
[3,0,1]])
print((label==1).nonzero())
输出:
tensor([[0, 0],
[1, 2]])
import torch
label = torch.tensor([[1,0,0],
[3,0,1]])
print((label>1).nonzero())
输出:
tensor([[1, 0]])
栗子2:
torch.nonzero(torch.tensor([1, 1, 1, 0, 1]))
#tensor([[ 0],
# [ 1],
# [ 2],
# [ 4]])
torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
[0.0, 0.4, 0.0, 0.0],
[0.0, 0.0, 1.2, 0.0],
[0.0, 0.0, 0.0,-0.4]]))
#tensor([[ 0, 0],
# [ 1, 1],
# [ 2, 2],
# [ 3, 3]])
torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True)
#(tensor([0, 1, 2, 4]),)
torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
[0.0, 0.4, 0.0, 0.0],
[0.0, 0.0, 1.2, 0.0],
[0.0, 0.0, 0.0,-0.4]]), as_tuple=True)
#(tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]))
#组合起来就是(0,0)(1,1)(2,2)(3,3)
torch.nonzero(torch.tensor(5), as_tuple=True)
#(tensor([0]),)
参考:
- https://blog.csdn.net/qq_36076233/article/details/106642384
- https://blog.csdn.net/qq_36530992/article/details/102836509