torch.nonzero(input, *, out=None, as_tuple=False) → LongTensor or tuple of LongTensors
功能:用于输出数组的非零值的索引,即用来定位数组中非零的元素
输入:
input
:输入的数组as_tuple
:如果设为False
,则返回一个二维张量,其中每一行都是非零值的索引,如果输入的数组有n维,则输出的张量维度大小为z×n,其中z为input
非零元素的总数;如果设为True
,则返回一个由一维张量组成的元组,如果输入数组为n维,则有n个一维张量,每个一维张量对应非零元素特定维度的索引(第一个张量数组储存的是所有非零元素第一维度的索引),并且每个张量里面有z个数,其中z为输入数组非零元素的个数。注意:
torch.nonzero
也可以通过tensor.nonzero
的方式调用具体用法可见下面的代码案例
一般用法
import torch
import numpy as np
# 从0,1之间随机选20个数
a=torch.from_numpy(np.random.choice(2,20)).reshape(4,5)
b=torch.nonzero(a)
print(a)
print(b)
输出
# 原数组
tensor([[0, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[1, 0, 1, 0, 1],
[1, 0, 0, 0, 0]], dtype=torch.int32)
# 默认情况下as_tuple为False,二维张量数组
tensor([[1, 0],
[2, 0],
[2, 2],
[2, 4],
[3, 0]])
as_tuple设为True与False的区别
import torch
import numpy as np
# 从0,1之间随机选12个数
a=torch.from_numpy(np.random.choice(2,12)).reshape(2,2,3)
b=torch.nonzero(a,as_tuple=False)
c=torch.nonzero(a,as_tuple=True)
print(a)
print(b)
print(c)
输出
# 原数组
tensor([[[0, 0, 1],
[0, 0, 1]],
[[1, 1, 0],
[1, 0, 0]]], dtype=torch.int32)
# as_tuple设为False时,返回一个二维数组
# 行代表不同的非零元素点,列代表坐标
# 由于是三维数组,所以有三列,每列对应三个维度
tensor([[0, 0, 2],
[0, 1, 2],
[1, 0, 0],
[1, 0, 1],
[1, 1, 0]])
# as_tuple设为False时,返回一个元组
# 这里输入数组是三维,所以元组里面有三个一维数组
# 第一个一维数组代表原数组的第一维度,后面一次类推
# 第一个一维数组的第一个元素代表第一个非零元素的第一维度坐标,后面依次类推
(tensor([0, 0, 1, 1, 1]), tensor([0, 1, 0, 0, 1]), tensor([2, 2, 0, 1, 0]))
常用于特定数组元素的定位操作
import torch
a=torch.arange(20).reshape(4,5)
# 输出数组a中,数值为5的元素索引
b=torch.nonzero(a==5)
print(a)
print(b)
输出
# 原数组
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]])
# 数值为5的元素索引
tensor([[1, 0]])
torch.nonzero:https://pytorch.org/docs/stable/generated/torch.nonzero.html?highlight=torch%20nonzero#torch.nonzero