pytorch中几种tensor掩码的获取方法(含代码)

方式一: 直接取布尔值

输入:

target = torch.Tensor([1,0,0,2,0,0,3])
mask = (target > 0)
masked_target = target[mask]

print(target)
print(mask)
print(masked_target)

输入:

target = torch.Tensor([1,0,0,2,0,0,3])
mask = target.ge(0)
masked_target = torch.masked_select(target, mask)

print(target)
print(mask)
print(masked_target)

输出:

tensor([1., 0., 0., 2., 0., 0., 3.])
tensor([ True, False, False,  True, False, False,  True])
tensor([1., 2., 3.])

 

方式二:自己设置ByteTensor作为掩码

输入:

target = torch.Tensor([1,0,0,2,0,0,3])
mask = torch.ByteTensor([1,0,0,1,0,0,0])
masked_target = torch.masked_select(target, mask)

print(target)
print(masked_target)

输出:

tensor([1., 0., 0., 2., 0., 0., 3.])
tensor([1., 2.])

你可能感兴趣的:(Python学习,pytorch)