【pytorch】 获得索引的方法

                                  运算      函数

                                  大于     torch.gt

                                  小于     torch.lt

                                  等于     torch.eq

                                  非零     torch.nonzero

                                  非         torch.ne

import torch

x = torch.arange(5)   
print(x)
mask = torch.gt(x,1)   # 大于
print(mask)
print(x[mask])

x = torch.arange(5)   
print(x)
mask = torch.lt(x,3)   # 小于
print(mask)
print(x[mask])

x = torch.arange(5)   
print(x)
mask = torch.eq(x,3)   # 等于
print(mask)
print(x[mask])

x = torch.Tensor([1,2,1,0,0])
mask = torch.ne(x,1)   # 非,一个数
print(mask)
print(x[mask])

a = torch.Tensor([[0.6, 0.0, 0.0, 0.0],[0.0, 0.4, 0.0, 0.0],[0.0, 0.0, 1.2, 0.0],[0.0, 0.0, 0.0,-0.4]])
mask = torch.nonzero(a)   # 非零
print(mask)
print(torch.numel(mask))
print(torch.numel(a))
# print(a[mask])
print(torch.numel(mask)/torch.numel(a))

 

你可能感兴趣的:(pytorch)