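Method 1: index directly with a boolean mask
Input:
import torch
target = torch.Tensor([1,0,0,2,0,0,3])
mask = (target > 0)  # bool tensor: True where the element is strictly positive
masked_target = target[mask]  # boolean indexing keeps only the True positions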
print(target)
print(mask)
print(masked_target)
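Equivalent input, using torch.masked_select:
import torch
target = torch.Tensor([1,0,0,2,0,0,3])
mask = target.gt(0)  # gt = strictly greater than; ge(0) would also keep the zeros
masked_target = torch.masked_select(target, mask)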
print(target)
print(mask)
print(masked_target)
Output (the same for both snippets):
tensor([1., 0., 0., 2., 0., 0., 3.])
tensor([ True, False, False, True, False, False, True])
tensor([1., 2., 3.])
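A minimal sketch, reusing the tensor values from the examples above, that compares boolean indexing with torch.masked_select and shows why the comparison operator matters: gt(0) means strictly greater than zero, while ge(0) means greater than or equal to zero and therefore keeps the zeros as well. The variable names are illustrative only.

```python
import torch

target = torch.tensor([1., 0., 0., 2., 0., 0., 3.])

# gt(0): strictly greater than zero, same as (target > 0)
mask_gt = target.gt(0)
# ge(0): greater than OR equal to zero, so every element here passes
mask_ge = target.ge(0)

by_index = target[mask_gt]                        # boolean indexing
by_select = torch.masked_select(target, mask_gt)  # explicit function call

print(by_index)                              # tensor([1., 2., 3.])
print(torch.equal(by_index, by_select))      # True
print(torch.masked_select(target, mask_ge))  # tensor([1., 0., 0., 2., 0., 0., 3.])
```

Both styles return the same 1-D result for a given mask; the choice is mostly a matter of readability.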
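Method 2: construct the mask by hand
Older PyTorch versions accepted a ByteTensor (uint8) as the mask here; newer versions deprecate uint8 masks in favor of bool, so the mask below is created with dtype=torch.bool.
Input:
import torch
target = torch.Tensor([1,0,0,2,0,0,3])
mask = torch.tensor([1,0,0,1,0,0,0], dtype=torch.bool)  # hand-written mask: keep positions 0 and 3
masked_target = torch.masked_select(target, mask)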
print(target)
print(masked_target)
Output:
tensor([1., 0., 0., 2., 0., 0., 3.])
tensor([1., 2.])
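A minimal sketch of building a mask from a plain 0/1 list, assuming a PyTorch version with bool tensors (1.2 or later). It also converts a uint8 tensor (the old ByteTensor style) with .bool(), and shows that masked_select always returns a 1-D tensor, even when the input is 2-D. The names flags, byte_mask and grid are illustrative only.

```python
import torch

target = torch.tensor([1., 0., 0., 2., 0., 0., 3.])

# Hand-written 0/1 flags, turned directly into a bool mask
flags = [1, 0, 0, 1, 0, 0, 0]
mask = torch.tensor(flags, dtype=torch.bool)
print(torch.masked_select(target, mask))  # tensor([1., 2.])

# An existing uint8 mask (old ByteTensor style) converted with .bool()
byte_mask = torch.tensor(flags, dtype=torch.uint8)
print(torch.masked_select(target, byte_mask.bool()))  # tensor([1., 2.])

# masked_select flattens the selected elements into a 1-D tensor (row-major order)
grid = torch.tensor([[1., 0.], [2., 3.]])
print(torch.masked_select(grid, grid > 0))  # tensor([1., 2., 3.])
```

Converting with .bool() keeps old uint8 masks usable without rewriting the code that produces them.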