pytorch tensor 获取指定值的indices

import torch

x  = torch.randn(1, 3, 6, 6)
y = torch.zeros(x.shape).to(x.device)
y[x >= 0.5] = 1
z = (y == 1).nonzero(as_tuple = False) 
print(z)

# z is the indices you need.

你可能感兴趣的:(pytorch,人工智能,机器学习)