一维情况
import torch
a = torch.randn((8,))
b = torch.randint(0,2,(8,))
slice_ = (b==0)
print(f"a = {a}")
print(f"slice_ = {slice_}")
print(f"a[slice_] = {a[slice_]}")
输出结果为:
a = tensor([-0.5343, -1.2582, -0.4511, -0.4338, 2.2691, 0.4879, 0.6847, 0.6235])
slice_ = tensor([ True, True, False, False, False, False, False, False])
a[slice_] = tensor([-0.5343, -1.2582])
多维情况
import torch
a = torch.randn((2,3,4))
b = torch.randint(0,2,(2,3,4))
slice_ = (b==0)
print(f"a = {a}")
print(f"slice_ = {slice_}")
print(f"a[slice_] = {a[slice_]}")
输出结果为:
a = tensor([[[ 1.9325, 1.2950, 1.2434, 0.0564],
[ 0.3010, 0.0343, 0.7497, 0.4019],
[ 0.3159, -2.3188, 0.3495, -0.3471]],
[[ 1.0270, 0.9790, 0.9406, 0.3484],
[ 0.7881, 0.7568, 1.8638, 0.4024],
[-0.5964, -2.3572, -0.6636, 0.8282]]])
slice_ = tensor([[[False, True, True, False],
[ True, True, True, False],
[ True, True, False, False]],
[[ True, True, False, True],
[ True, True, True, False],
[False, True, False, False]]])
a[slice_] = torch.Size([14])
当两者维度相同时,在torch.tensor里面和true、false矩阵中true位置相同的元素会被保留,其余的值丢掉,最终切片结果为一维张量
首先要说明的是,两者维度可以不相同,但是每一个维度的值必须相同,比如说一个张量shape为(2,3,4,5),那true、false矩阵的shape可以是(2,)、(2,3)、(2,3,4)、(2,3,4,5)四种情况
import torch
a = torch.randn((2,3,4))
b = torch.randint(0,2,(2,3))
slice_ = (b==0)
print(f"a = {a}, a.shape = {a.shape}")
print(f"slice_ = {slice_}, slice_.shape = {slice_.shape}")
print(f"a[slice_] = {a[slice_]}, a[slice_].shape = {a[slice_].shape}")
输出结果为
a = tensor([[[-0.2952, -0.2619, -0.8608, 1.2657],
[ 0.1895, -0.4806, -1.5506, 0.2752],
[-0.2219, 0.2185, -0.7038, 0.1399]],
[[-1.9745, 0.7333, -1.0359, 1.4674],
[ 1.6730, 0.1612, 0.3537, -0.1737],
[-0.9188, 3.0544, 1.4211, 0.9257]]]), a.shape = torch.Size([2, 3, 4])
slice_ = tensor([[ True, True, True],
[ True, False, False]]), slice_.shape = torch.Size([2, 3])
a[slice_] = tensor([[-0.2952, -0.2619, -0.8608, 1.2657],
[ 0.1895, -0.4806, -1.5506, 0.2752],
[-0.2219, 0.2185, -0.7038, 0.1399],
[-1.9745, 0.7333, -1.0359, 1.4674]]), a[slice_].shape = torch.Size([4, 4])
当两者维度不相同时,在上面例子中,torch.tensor的shape为(2,3,4),true、false矩阵shape为(2,3),两者的前两个维度相同,那torch.tensor就保留true位置的元素,只不过这个被保留元素的shape为(4,),又因为存在四个Ture,所以最后切片结果shape为(4,4)