Tensor多维slice切片操作

今天在Debug代码的时候,遇到了由slice函数对象组成的list,对一个3维Tensor进行切片操作,经过一段时间查询资料,未能找到相关资料,自己动手写程序验证了下想法,此博客记录下验证过程。

我debug的程序生成了两个slice对象,并组合成了list(这篇博客后面成为slice list)在一起:slices=[slice(None, None, None), slice(0, 1, None)],然后有一个shape为[1,61,1]的tensor(姑且叫做a吧),经过a[slices]索引后,得到了一个shape为[1,1,1]的tensor,这个tensor是tensor a 的第一个元素。接下来的文章,我会通过程序验证,多维slice是如何对多维tensor进行切片的。

先说下结论:

  • slice list中的每一个slice对象的下标为其作用在tensor上的维度,并且slice对tensor每一维的操作可视做对普通一维向量操作。即,slice list中的第一个slice对象(下标为0),对tensor的第0维进行切片,以此类推。
  • 允许用来切片的slice list的维度小于等于tensor的维度。即,如果tensor是一个维度为4,shape为[x,y,z,u],那么切片的list最多包含4个slice对象。

接下来是实验部分:

slice list中的每一个slice对象的下标为其作用在tensor上的维度

创建一个shape为[2,2,3,3]的tensor b,让其中的元素从1线性增加(为了方便查看索引后的元素在tensor中的位置)

>>> a = torch.linspace(1,2*2*3*3, 2*2*3*3)
>>> a
tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
        15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28.,
        29., 30., 31., 32., 33., 34., 35., 36.])
# reshape为[2,2,3,3]
>>> b = a.reshape(2,2,3,3)
>>> b
tensor([[[[ 1.,  2.,  3.],
          [ 4.,  5.,  6.],
          [ 7.,  8.,  9.]],

         [[10., 11., 12.],
          [13., 14., 15.],
          [16., 17., 18.]]],


        [[[19., 20., 21.],
          [22., 23., 24.],
          [25., 26., 27.]],

         [[28., 29., 30.],
          [31., 32., 33.],
          [34., 35., 36.]]]])

然后创建一个slice list,包含两个slice对象。其中,第一个slice对象表示从0开始,到最后一个元素(不包含最后一个元素),step默认为1。第二个slice对象表示从1开始,到第二个元素(不包含第二个元素),step默认为1。

>>> s = [slice(0,-1), slice(1,2)]
>>> s
[slice(0, -1, None), slice(1, 2, None)]

通过打印索引后对象的shape,可以看出,tensor由原来的[2,2,3,3]变成了[1,1,3,3]

>>> b[s].shape
torch.Size([1, 1, 3, 3])

接下来再来看具体slice作用的情况。

原来tensor的shape为[2,2,3,3],其中维度0对应shape值为2,可以理解为tensor最外层有两个大块,分别是两个shape为[2,3,3]的tensor:

第一块

[[[ 1.,  2.,  3.],
 [ 4.,  5.,  6.],
 [ 7.,  8.,  9.]],

[[10., 11., 12.],
 [13., 14., 15.],
 [16., 17., 18.]]]

第二块:

[[[19., 20., 21.],
  [22., 23., 24.],
  [25., 26., 27.]],

 [[28., 29., 30.],
  [31., 32., 33.],
  [34., 35., 36.]]]

然后针对shape为[2,3,3]的tensor,其中维度0的含义和刚才一样可以类推,即分成两个shape为[3,3]的tensor块。

分别是这4小块(shape为[3,3]):

[[ 1.,  2.,  3.],
 [ 4.,  5.,  6.],
 [ 7.,  8.,  9.]]
[[10., 11., 12.],
 [13., 14., 15.],
 [16., 17., 18.]]
[[19., 20., 21.],
  [22., 23., 24.],
  [25., 26., 27.]]
[[28., 29., 30.],
  [31., 32., 33.],
  [34., 35., 36.]]

后面的维度依然可以以同样的类比得到更细分的块。

这样分块用来干嘛呢?方面我们理解slice的作用。

接下来,我们看看刚才tensor b 被slice list索引后的结果,如下:

>>> b[s]
tensor([[[[10., 11., 12.],
          [13., 14., 15.],
          [16., 17., 18.]]]])

仔细观察这个块,他是shape为[2,2,3,3]的tensor的第一大块中的第二小块的值。

那么为什么是第一大块以及,第一大块中的第二小块呢。看看slice操作就能得到答案了。

slice list为[slice(0, -1, None), slice(1, 2, None)],第一个slice对象的作用在tensor b的维度0,由于维度0对应shape数值为2,所以索引的结果是[0],即第一大块。第二个slice对象作用中tensor b的维度1,由于维度1对应shape为2,所以得到索引结果为[1],即第一大块中的第二小块

接下来再做一个实验,验证下。我们想取第二大块里面的第一小块,的第二行的值。slice list应该为[slice(1,2,None), slice(0,1,None), slice(1,None,None)],正确的结果应该是[22., 23., 24.]。接下来在python中验证下结果:

>>> s2 = [slice(1,2,None), slice(0,1,None), slice(1,2,None)]
>>> s2
[slice(1, 2, None), slice(0, 1, None), slice(1, 2, None)]
>>> b[s2]
tensor([[[[22., 23., 24.]]]])

结果得到的和我们预想的一致,所以这就验证了slice list中的每一个slice对象的下标为其作用在tensor上的维度

允许用来切片的slice list的维度小于等于tensor的维度

这一点性质应该不难理解,tensor有多少维度,就有多少维度能够被进行切片。通过如下程序可以验证:

>>> s2
[slice(1, 2, None), slice(0, 1, None), slice(1, 2, None)]
>>> c = a.reshape(4,9)
>>> c
tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18.],
        [19., 20., 21., 22., 23., 24., 25., 26., 27.],
        [28., 29., 30., 31., 32., 33., 34., 35., 36.]])
>>> c[s2]
Traceback (most recent call last):
  File "", line 1, in <module>
IndexError: too many indices for tensor of dimension 2

将a reshape为2为tensor,然后用一个3维slice list对其进行索引,结果程序出错,这就验证了允许用来切片的slice list的维度小于等于tensor的维度

最后

本博客是记录个人学习slice list对多维tensor进行切片的操作,只是通过程序验证了自己的猜想,但是未能找到相关的其他参考资料,希望有相关参考文献的朋友可以留言给我,感激不尽~

你可能感兴趣的:(人工智能从入门到放弃,python,深度学习,pytorch)