pytorc torch.uint8与torch.long/ torch. float

版本: pytorch1.0
目的: tensor中的index与mask


例子

下面针对torch.longtorch.uint8数据类型在index/mask 中的不同作用进行分析

t = torch.rand(42)
"""
tensor([[0.5492, 0.2083],
        [0.3635, 0.5198],
        [0.8294, 0.9869],
        [0.2987, 0.0279]])
"""
# 注意数据类型是 uint8, 
mask= torch.ones(4,dtype=torch.uint8)
mask[2] = 0
print(mask)
print(t[mask, :])
"""
 tensor([1, 1, 0, 1], dtype=torch.uint8)
 
 tensor([[0.5492, 0.2083],
        [0.3635, 0.5198],
        [0.2987, 0.0279]]) 
"""

# 注意, 数据类型是long
idx = torch.ones(3,dtype=torch.long)
idx[1] = 0
print(idx)
print(t[idx, :])
"""
tensor([1, 0, 1])
tensor([[0.3635, 0.5198],
        [0.5492, 0.2083],
        [0.3635, 0.5198]])
"""

结论

  • mask的数据类型是torch.uint8时,此时的tensor用作mask,tensor中的1对应的行/列保留,0对应的行/列舍去。且被mask的维度必须与原始tensor的维度一致。其实很好理解,因为你是一个mask,是要覆盖在原始tensor上面的,因此需要你和原始tensor保持一致的dimension。上面的例子中,需要保证 mask.size(0) == t.shape(0),否则会报错。
  • idx的数据类型是torch.long时,此时的tensor用作index,tensor中的每个数字代表着将要取出的tensor的行列索引。用作index时是为了从原始的tensor中取出指定的行列,因此,取出多少不受限。就上面的例子而言不需要保证 idx.size(0) == t.shape(0)

你可能感兴趣的:(pytorch,数据结构)