今天在Debug代码的时候,遇到了由slice函数对象组成的list,对一个3维Tensor进行切片操作,经过一段时间查询资料,未能找到相关资料,自己动手写程序验证了下想法,此博客记录下验证过程。
我debug的程序生成了两个slice对象,并组合成了list(这篇博客后面成为slice list)在一起:slices=[slice(None, None, None), slice(0, 1, None)]
,然后有一个shape为[1,61,1]
的tensor(姑且叫做a吧),经过a[slices]
索引后,得到了一个shape为[1,1,1]
的tensor,这个tensor是tensor a 的第一个元素。接下来的文章,我会通过程序验证,多维slice是如何对多维tensor进行切片的。
先说下结论:
[x,y,z,u]
,那么切片的list最多包含4个slice对象。接下来是实验部分:
创建一个shape为[2,2,3,3]的tensor b,让其中的元素从1线性增加(为了方便查看索引后的元素在tensor中的位置)
>>> a = torch.linspace(1,2*2*3*3, 2*2*3*3)
>>> a
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28.,
29., 30., 31., 32., 33., 34., 35., 36.])
# reshape为[2,2,3,3]
>>> b = a.reshape(2,2,3,3)
>>> b
tensor([[[[ 1., 2., 3.],
[ 4., 5., 6.],
[ 7., 8., 9.]],
[[10., 11., 12.],
[13., 14., 15.],
[16., 17., 18.]]],
[[[19., 20., 21.],
[22., 23., 24.],
[25., 26., 27.]],
[[28., 29., 30.],
[31., 32., 33.],
[34., 35., 36.]]]])
然后创建一个slice list,包含两个slice对象。其中,第一个slice对象表示从0开始,到最后一个元素(不包含最后一个元素),step默认为1。第二个slice对象表示从1开始,到第二个元素(不包含第二个元素),step默认为1。
>>> s = [slice(0,-1), slice(1,2)]
>>> s
[slice(0, -1, None), slice(1, 2, None)]
通过打印索引后对象的shape,可以看出,tensor由原来的[2,2,3,3]变成了[1,1,3,3]
>>> b[s].shape
torch.Size([1, 1, 3, 3])
接下来再来看具体slice作用的情况。
原来tensor的shape为[2,2,3,3],其中维度0对应shape值为2,可以理解为tensor最外层有两个大块,分别是两个shape为[2,3,3]的tensor:
第一块
[[[ 1., 2., 3.],
[ 4., 5., 6.],
[ 7., 8., 9.]],
[[10., 11., 12.],
[13., 14., 15.],
[16., 17., 18.]]]
第二块:
[[[19., 20., 21.],
[22., 23., 24.],
[25., 26., 27.]],
[[28., 29., 30.],
[31., 32., 33.],
[34., 35., 36.]]]
然后针对shape为[2,3,3]的tensor,其中维度0的含义和刚才一样可以类推,即分成两个shape为[3,3]的tensor块。
分别是这4小块(shape为[3,3]):
[[ 1., 2., 3.],
[ 4., 5., 6.],
[ 7., 8., 9.]]
[[10., 11., 12.],
[13., 14., 15.],
[16., 17., 18.]]
[[19., 20., 21.],
[22., 23., 24.],
[25., 26., 27.]]
[[28., 29., 30.],
[31., 32., 33.],
[34., 35., 36.]]
后面的维度依然可以以同样的类比得到更细分的块。
这样分块用来干嘛呢?方面我们理解slice的作用。
接下来,我们看看刚才tensor b 被slice list索引后的结果,如下:
>>> b[s]
tensor([[[[10., 11., 12.],
[13., 14., 15.],
[16., 17., 18.]]]])
仔细观察这个块,他是shape为[2,2,3,3]的tensor的第一大块中的第二小块的值。
那么为什么是第一大块以及,第一大块中的第二小块呢。看看slice操作就能得到答案了。
slice list为[slice(0, -1, None), slice(1, 2, None)]
,第一个slice对象的作用在tensor b的维度0,由于维度0对应shape数值为2,所以索引的结果是[0],即第一大块。第二个slice对象作用中tensor b的维度1,由于维度1对应shape为2,所以得到索引结果为[1],即第一大块中的第二小块。
接下来再做一个实验,验证下。我们想取第二大块里面的第一小块,的第二行的值。slice list应该为[slice(1,2,None), slice(0,1,None), slice(1,None,None)]
,正确的结果应该是[22., 23., 24.]
。接下来在python中验证下结果:
>>> s2 = [slice(1,2,None), slice(0,1,None), slice(1,2,None)]
>>> s2
[slice(1, 2, None), slice(0, 1, None), slice(1, 2, None)]
>>> b[s2]
tensor([[[[22., 23., 24.]]]])
结果得到的和我们预想的一致,所以这就验证了slice list中的每一个slice对象的下标为其作用在tensor上的维度。
这一点性质应该不难理解,tensor有多少维度,就有多少维度能够被进行切片。通过如下程序可以验证:
>>> s2
[slice(1, 2, None), slice(0, 1, None), slice(1, 2, None)]
>>> c = a.reshape(4,9)
>>> c
tensor([[ 1., 2., 3., 4., 5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14., 15., 16., 17., 18.],
[19., 20., 21., 22., 23., 24., 25., 26., 27.],
[28., 29., 30., 31., 32., 33., 34., 35., 36.]])
>>> c[s2]
Traceback (most recent call last):
File "" , line 1, in <module>
IndexError: too many indices for tensor of dimension 2
将a reshape为2为tensor,然后用一个3维slice list对其进行索引,结果程序出错,这就验证了允许用来切片的slice list的维度小于等于tensor的维度。
本博客是记录个人学习slice list对多维tensor进行切片的操作,只是通过程序验证了自己的猜想,但是未能找到相关的其他参考资料,希望有相关参考文献的朋友可以留言给我,感激不尽~