PyTorch入门——数组维度及squeeze、unsqueeze

初学PyTorch,觉得这里的数组维度系统比起matlab要更加复杂,自己的理解还是不够透彻,这个问题在接触到squeeze、unsqueeze两个函数是更加凸显出来。

今天忽然感觉理解精进了一些,执笔记录。

数组、矩阵、向量

这三个词的区别问题之前已经抛出来过,这里专门再提一下。

  • 向量:一维数组
  • 矩阵:二维数组
  • 数组:包括一维、二维数组、高维数组

广义上来说,数组包括矩阵和向量,矩阵也包括向量。只不过因为我们省略了长度为1的维度,而提出了向量、矩阵两个说法。

在matlab里面向量和矩阵的差别基本可以忽略,高维数组由于需要分成n个二维数组来展示,所以可能不那么直观,更高维度的矩阵其实用得也比较少。

Tensor

Tensor的中文翻译叫”张量“。我认为造成理解上障碍的“罪魁祸首”就是PyTorch里面定义数组时没有分号;来表示纵向的扩展,我们虽然有时还是会说几行几列,但实际上根本不存在明确的“行列”概念,我觉得称之为“嵌套”更加合适。

a = torch.rand(1, 2, 2, 2)
print(a)
print(a.size())
------------------
tensor([[[[0.4160, 0.7013],
          [0.3883, 0.6756]],

         [[0.6303, 0.0144],
          [0.1027, 0.8542]]]])
torch.Size([1, 2, 2, 2])

看上面的例子,4维数组,如果是在matlab里面,会拆成n个二维数组(矩阵)展示出来。但这里是以嵌套的形式一次性输出的,起初我觉得这样有点乱,但是今天我恍然发现,输出的中括号的层次完美地表现了数组的层次!

PyTorch里面高维数组定义和matlab里面不同:

matlab中前两个数分别是行、列数,后面则是更高维度的维度长度,拿上面的(1, 2, 2, 2)举例子,应该是有2*21*2,即一行二列的矩阵;

PyTorch中的各个数字从左往右分别对应由外到内的层次。(a,b,c,d)表示最外层数组包含a个元素(元素本身也是数组,只有到最内层才是普通的数),a个元素各自又包含b个子数组,b个子数组各自又包含c个子数组,c个子数组每个都包含d个元素。

下面对照着括号来看:

  • 最外层的中括号:数组的象征,无论多少维的数组,最外面都有一层中括号,所以最外面的中括号本身并不包含任何维度信息。
  • 第二层:在垂直方向上观察,只有一层,说明该维度长度为1(对应(a,b,c,d)中的a);
  • 第三层:在垂直方向上观察,有两层,说明该维度长度为2(对应(a,b,c,d)中的b);
  • 第四层:在垂直方向上观察,由于上一层是有两层的,所以pycharm的输出也非常贴心地分成了上下两块,我们只用看第一块即可,显然也是两层,说明该维度长度为2(对应(a,b,c,d)中的c);
  • 第五层:也是最内层,随便选中一个中括号,里面是横着排列的,用逗号分隔的2个数,说明该维度长度为2(对应(a,b,c,d)中的d);

上述过程完全没有考虑“行列”的概念,我认为这样在PyTorch中反而不容易迷糊。下面再举一个例子。

a = torch.rand(2, 3, 4, 5)
print(a)
print(a.size())
-----------------------
tensor([[[[0.3689, 0.0515, 0.4698, 0.8608, 0.2939],
          [0.6328, 0.7154, 0.4514, 0.8239, 0.3587],
          [0.5970, 0.1604, 0.0033, 0.8885, 0.3629],
          [0.2186, 0.9431, 0.2264, 0.3357, 0.3118]],

         [[0.8402, 0.7487, 0.8137, 0.0692, 0.9861],
          [0.1275, 0.5480, 0.3803, 0.3801, 0.6754],
          [0.7389, 0.3532, 0.5560, 0.4056, 0.2368],
          [0.1113, 0.3072, 0.6570, 0.1285, 0.6331]],

         [[0.0912, 0.3514, 0.2731, 0.9596, 0.1936],
          [0.3107, 0.4428, 0.9672, 0.0778, 0.6484],
          [0.3629, 0.7911, 0.9783, 0.7051, 0.5235],
          [0.1730, 0.8745, 0.1580, 0.3193, 0.8202]]],


        [[[0.4327, 0.4810, 0.0056, 0.8400, 0.3263],
          [0.1467, 0.8376, 0.0766, 0.5909, 0.4188],
          [0.3555, 0.7011, 0.2004, 0.2605, 0.5205],
          [0.6036, 0.8388, 0.0610, 0.1489, 0.9452]],

         [[0.5051, 0.0161, 0.3363, 0.1939, 0.9949],
          [0.7931, 0.2976, 0.6276, 0.3221, 0.1810],
          [0.7623, 0.5226, 0.7116, 0.4818, 0.5510],
          [0.9556, 0.3049, 0.3479, 0.9650, 0.6561]],

         [[0.8363, 0.0121, 0.5926, 0.7543, 0.4924],
          [0.2830, 0.3250, 0.7983, 0.2548, 0.3496],
          [0.2930, 0.8676, 0.3479, 0.1776, 0.7081],
          [0.4618, 0.9499, 0.9068, 0.7570, 0.2282]]]])
torch.Size([2, 3, 4, 5])

采用层次的眼光来看这个输出结果,瞬间觉得清晰无比!最外层的中括号,是数组的象征,不用管,下一层,竖直方向有两个中括号,对应第一个2;下一层,竖直方向有3个中括号,对应第2个3;再下一层,4个中括号,对应4;最内层,5个数,对应最后一个5。

如果还保留着“行列”的概念,我觉得会被这一层一层的中括号绕晕了。

配合着size的输出来看,更加助于理解。

squeeze、unsqueeze

我觉得如果没有上面的概念,看squeeze、unsqueeze两个函数的用法,只能明白意思:一个降维一个升维,但具体细节(dim的取值)根本难以理解。

今天上午我盯着不同dim情况下unsqueeze的输出结果看了白天,完全是云里雾里!!到底增加的长度为1的维度加在哪里了?

这里其实有两个问题,一是不知道dim到底怎么起作用;二是不知道怎么看Tensor的维度,不然肯定是可以知道多出来的1个维度加在哪里的。

第二个问题很简单,就是用上面的size函数,直接可以输出Tensor的维度信息。

第一个问题,我受到了该回答的启发。下面详细介绍两个函数。

squeeze

作用:消去一个或多个维度。只有在指定的维度长度=1时才生效

用法:

  • a.squeeze(0):消去第1个维度(这里的索引是从0开始的,和数组索引一样)
  • a.squeeze(0,2):消去第1、3个维度
  • a.squeeze():消去所有长度=1的维度

再次声明:如果指定的维度长度不为1,不报错但是也不会有任何作用。

a = torch.rand(1, 2, 1, 4, 5)
b = a.squeeze(0)
c = a.squeeze(1)
d = a.squeeze()
print(a.size())
print(b.size())
print(c.size())
print(d.size())
-------------------
torch.Size([1, 2, 1, 4, 5])  # 原数组维度
torch.Size([2, 1, 4, 5])  # 第一个维度长度=1,被消去
torch.Size([1, 2, 1, 4, 5])  # 第2个维度长度=2,不起作用
torch.Size([2, 4, 5])  # 第1、3两个长度=1的维度都被消去

具体的数组我就不输出了,我觉得还不如直接看size来得清楚。我就不多解释了,结合上面的用法说明,相信大家多能看明白。

注意:对于一个n维矩阵,dim取值的范围是-n~n-1n-1的由来大家很清楚,就是因为索引从0开始,-n的由来是因为索引也可以采用倒序的方式,对于上面的a数组,a.squeeze(-1)表示消去长度=5的维度(当然不会起作用),a.squeeze(-5)表示消去最左边长度=1的维度。这种用法可能会用于数组的长度太长,需要倒序索引的情况,但我相信这种情况可能比较少。

unsqueeze

作用:增加一个长度=1的维度

用法:a.unsqueeze(t),在第t+1个空位(两个维度之间有一个空位,第一个维度前面和最后一个维度后面各有一个空位)增加一个维度。

注意:一次只能加一个,也就是说只能有一个输入,而不能a.unsqueeze(m,n)

a = torch.rand(2, 2, 2, 2, 2)
b = a.unsqueeze(0)
c = a.unsqueeze(1)
d = a.unsqueeze(2)
----------------
torch.Size([2, 2, 2, 2, 2])
torch.Size([1, 2, 2, 2, 2, 2])
torch.Size([2, 1, 2, 2, 2, 2])
torch.Size([2, 2, 1, 2, 2, 2])

这里的dim取值范围为-(n+1)~n,正负范围都比squeeze要宽一个数,很好理解:5个人站一排,中间有6个空位。到了这里大家应该更加理解我前面的”第t+1个空位“是什么意思了

总结

刚接触unsqueeze,我看着输入不用的dim值得到的输出结果陷入沉思,试图用肉眼解析出dim到底是控制了啥?维度到底加在哪里了?甚至发现dim的取值范围是可以变化的!更加有些崩溃。

我认为最重要的理解上的突破就是从嵌套关系、嵌套层次的角度去理解Tensor高维数组,摒弃了“行列”的概念。我称之为理解了“维度”的概念。

理解了“维度”的概念后,回过头看usqueeze和unsqueeze,在指定位置加减维度而已,根本不用把数组输出来,直接看size,之前的困惑一扫而空。

你可能感兴趣的:(python)