【pytorch】nonzero

【pytorch】nonzero

torch.nonzero(input, *, out=None, as_tuple=False)

  • input:输入的必须是tensor
  • out:输出 z × n z\times n z×n n n n代表输入数据的维度, z z z是总共非0元素的个数
  • as_tuple:
  1. if false:输出的每一行为非零元素的索引
  2. if true:输出是每一个维度都有一个一维的张量
torch.nonzero(torch.tensor([1, 1, 1, 0, 1]))
#tensor([[ 0],
#        [ 1],
#        [ 2],
#        [ 4]])

torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
                            [0.0, 0.4, 0.0, 0.0],
                            [0.0, 0.0, 1.2, 0.0],
                            [0.0, 0.0, 0.0,-0.4]]))
#tensor([[ 0,  0],
#        [ 1,  1],
#        [ 2,  2],
#        [ 3,  3]])

torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True)
#(tensor([0, 1, 2, 4]),)

torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0],
                            [0.0, 0.4, 0.0, 0.0],
                            [0.0, 0.0, 1.2, 0.0],
                            [0.0, 0.0, 0.0,-0.4]]), as_tuple=True)
#(tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]))
#组合起来就是(0,0)(1,1)(2,2)(3,3)

torch.nonzero(torch.tensor(5), as_tuple=True)
#(tensor([0]),)

你可能感兴趣的:(pytorch)