【Torch API】torch.nonzero用法详解

torch.nonzero(input, *, out=None, as_tuple=False) → LongTensor or tuple of LongTensors

功能:用于输出数组的非零值的索引,即用来定位数组中非零的元素

输入:

  • input:输入的数组
  • as_tuple:如果设为False,则返回一个二维张量,其中每一行都是非零值的索引,如果输入的数组有n维,则输出的张量维度大小为z×n,其中z为input非零元素的总数;如果设为True,则返回一个由一维张量组成的元组,如果输入数组为n维,则有n个一维张量,每个一维张量对应非零元素特定维度的索引(第一个张量数组储存的是所有非零元素第一维度的索引),并且每个张量里面有z个数,其中z为输入数组非零元素的个数。

注意:

  • torch.nonzero也可以通过tensor.nonzero的方式调用

具体用法可见下面的代码案例

代码案例

一般用法

import torch
import numpy as np
# 从0,1之间随机选20个数
a=torch.from_numpy(np.random.choice(2,20)).reshape(4,5)
b=torch.nonzero(a)
print(a)
print(b)

输出

# 原数组
tensor([[0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 1, 0, 1],
        [1, 0, 0, 0, 0]], dtype=torch.int32)
# 默认情况下as_tuple为False,二维张量数组
tensor([[1, 0],
        [2, 0],
        [2, 2],
        [2, 4],
        [3, 0]])

as_tuple设为True与False的区别

import torch
import numpy as np
# 从0,1之间随机选12个数
a=torch.from_numpy(np.random.choice(2,12)).reshape(2,2,3)
b=torch.nonzero(a,as_tuple=False)
c=torch.nonzero(a,as_tuple=True)
print(a)
print(b)
print(c)

输出

# 原数组
tensor([[[0, 0, 1],
         [0, 0, 1]],

        [[1, 1, 0],
         [1, 0, 0]]], dtype=torch.int32)
# as_tuple设为False时,返回一个二维数组
# 行代表不同的非零元素点,列代表坐标
# 由于是三维数组,所以有三列,每列对应三个维度
tensor([[0, 0, 2],
        [0, 1, 2],
        [1, 0, 0],
        [1, 0, 1],
        [1, 1, 0]])
# as_tuple设为False时,返回一个元组
# 这里输入数组是三维,所以元组里面有三个一维数组
# 第一个一维数组代表原数组的第一维度,后面一次类推
# 第一个一维数组的第一个元素代表第一个非零元素的第一维度坐标,后面依次类推
(tensor([0, 0, 1, 1, 1]), tensor([0, 1, 0, 0, 1]), tensor([2, 2, 0, 1, 0]))

常用于特定数组元素的定位操作

import torch
a=torch.arange(20).reshape(4,5)
# 输出数组a中,数值为5的元素索引
b=torch.nonzero(a==5)
print(a)
print(b)

输出

# 原数组
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]])
# 数值为5的元素索引
tensor([[1, 0]])

官方文档

torch.nonzero:torch.nonzero — PyTorch 1.13 documentation

你可能感兴趣的:(语法,深度学习,人工智能,python)