torch.nonzero(input, *, out=None, as_tuple=False) → LongTensor or tuple of LongTensors
功能:用于输出数组的非零值的索引,即用来定位数组中非零的元素
输入:
注意:
torch.nonzero
也可以通过tensor.nonzero
的方式调用具体用法可见下面的代码案例
一般用法
import torch
import numpy as np
# 从0,1之间随机选20个数
a=torch.from_numpy(np.random.choice(2,20)).reshape(4,5)
b=torch.nonzero(a)
print(a)
print(b)
输出
# 原数组
tensor([[0, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[1, 0, 1, 0, 1],
[1, 0, 0, 0, 0]], dtype=torch.int32)
# 默认情况下as_tuple为False,二维张量数组
tensor([[1, 0],
[2, 0],
[2, 2],
[2, 4],
[3, 0]])
as_tuple设为True与False的区别
import torch
import numpy as np
# 从0,1之间随机选12个数
a=torch.from_numpy(np.random.choice(2,12)).reshape(2,2,3)
b=torch.nonzero(a,as_tuple=False)
c=torch.nonzero(a,as_tuple=True)
print(a)
print(b)
print(c)
输出
# 原数组
tensor([[[0, 0, 1],
[0, 0, 1]],
[[1, 1, 0],
[1, 0, 0]]], dtype=torch.int32)
# as_tuple设为False时,返回一个二维数组
# 行代表不同的非零元素点,列代表坐标
# 由于是三维数组,所以有三列,每列对应三个维度
tensor([[0, 0, 2],
[0, 1, 2],
[1, 0, 0],
[1, 0, 1],
[1, 1, 0]])
# as_tuple设为False时,返回一个元组
# 这里输入数组是三维,所以元组里面有三个一维数组
# 第一个一维数组代表原数组的第一维度,后面一次类推
# 第一个一维数组的第一个元素代表第一个非零元素的第一维度坐标,后面依次类推
(tensor([0, 0, 1, 1, 1]), tensor([0, 1, 0, 0, 1]), tensor([2, 2, 0, 1, 0]))
常用于特定数组元素的定位操作
import torch
a=torch.arange(20).reshape(4,5)
# 输出数组a中,数值为5的元素索引
b=torch.nonzero(a==5)
print(a)
print(b)
输出
# 原数组
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19]])
# 数值为5的元素索引
tensor([[1, 0]])
torch.nonzero:torch.nonzero — PyTorch 1.13 documentation