Pytorch:torch.nonzero()函数

1 torch.nonzero()

找出tensor中非零的元素的索引返回一个包含输入 input 中非零元素索引的张量.输出张量中的每行包含 input 中非零元素的索引。

torch.nonzero(input, *, out=None, as_tuple=False)
参数 含义
input 输入的必须是tensor
out 输出z×n,n代表输入数据的维度,z是总共非0元素的个数
as_tuple
  • if false:输出的每一行为非零元素的索引
  • if true:输出是每一个维度都有一个一维的张量

栗子:

import torch

label = torch.tensor([[1,0,0],
                      [1,0,1]])
print(label.nonzero())

输出:
tensor([[0, 0],
        [1, 0],
        [1, 2]])

返回的结果就是非零元素的索引,其中[0,0]对应了第一行第一列的1,[1,0]对应了第二行第一列的1,[1,2]对应了第二行第三列的1。

延伸:

  • 有时我们只想得到一种元素对应的索引,比如我们只想要1对应的索引:
import torch

label = torch.tensor([[1,0,0],
                      [3,0,1]])
print((label==1).nonzero())

输出:
tensor([[0, 0],
        [1, 2]])
  • 或者,我们想要一定条件下的元素的索引,比如大于1的元素的索引:
import torch

label = torch.tensor([[1,0,0],
                      [3,0,1]])
print((label>1).nonzero())

输出:
tensor([[1, 0]])

栗子2:

torch.nonzero(torch.tensor([1, 1, 1, 0, 1]))
#tensor([[ 0],
#        [ 1],
#        [ 2],
#        [ 4]])

torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
                            [0.0, 0.4, 0.0, 0.0],
                            [0.0, 0.0, 1.2, 0.0],
                            [0.0, 0.0, 0.0,-0.4]]))
#tensor([[ 0,  0],
#        [ 1,  1],
#        [ 2,  2],
#        [ 3,  3]])

torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True)
#(tensor([0, 1, 2, 4]),)

torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
                            [0.0, 0.4, 0.0, 0.0],
                            [0.0, 0.0, 1.2, 0.0],
                            [0.0, 0.0, 0.0,-0.4]]), as_tuple=True)
#(tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]))
#组合起来就是(0,0)(1,1)(2,2)(3,3)

torch.nonzero(torch.tensor(5), as_tuple=True)
#(tensor([0]),)

参考:

  • https://blog.csdn.net/qq_36076233/article/details/106642384
  • https://blog.csdn.net/qq_36530992/article/details/102836509

 

你可能感兴趣的:(修仙之路:pytorch篇)