初学PyTorch,觉得这里的数组维度系统比起matlab要更加复杂,自己的理解还是不够透彻,这个问题在接触到squeeze、unsqueeze两个函数是更加凸显出来。
今天忽然感觉理解精进了一些,执笔记录。
这三个词的区别问题之前已经抛出来过,这里专门再提一下。
广义上来说,数组包括矩阵和向量,矩阵也包括向量。只不过因为我们省略了长度为1的维度,而提出了向量、矩阵两个说法。
在matlab里面向量和矩阵的差别基本可以忽略,高维数组由于需要分成n个二维数组来展示,所以可能不那么直观,更高维度的矩阵其实用得也比较少。
Tensor的中文翻译叫”张量“。我认为造成理解上障碍的“罪魁祸首”就是PyTorch里面定义数组时没有分号;
来表示纵向的扩展,我们虽然有时还是会说几行几列,但实际上根本不存在明确的“行列”概念,我觉得称之为“嵌套”更加合适。
a = torch.rand(1, 2, 2, 2)
print(a)
print(a.size())
------------------
tensor([[[[0.4160, 0.7013],
[0.3883, 0.6756]],
[[0.6303, 0.0144],
[0.1027, 0.8542]]]])
torch.Size([1, 2, 2, 2])
看上面的例子,4维数组,如果是在matlab里面,会拆成n个二维数组(矩阵)展示出来。但这里是以嵌套的形式一次性输出的,起初我觉得这样有点乱,但是今天我恍然发现,输出的中括号的层次完美地表现了数组的层次!
PyTorch里面高维数组定义和matlab里面不同:
matlab中前两个数分别是行、列数,后面则是更高维度的维度长度,拿上面的(1, 2, 2, 2)
举例子,应该是有2*2
个1*2
,即一行二列的矩阵;
PyTorch中的各个数字从左往右分别对应由外到内的层次。(a,b,c,d)
表示最外层数组包含a个元素(元素本身也是数组,只有到最内层才是普通的数),a个元素各自又包含b个子数组,b个子数组各自又包含c个子数组,c个子数组每个都包含d个元素。
下面对照着括号来看:
(a,b,c,d)
中的a);(a,b,c,d)
中的b);(a,b,c,d)
中的c);(a,b,c,d)
中的d);上述过程完全没有考虑“行列”的概念,我认为这样在PyTorch中反而不容易迷糊。下面再举一个例子。
a = torch.rand(2, 3, 4, 5)
print(a)
print(a.size())
-----------------------
tensor([[[[0.3689, 0.0515, 0.4698, 0.8608, 0.2939],
[0.6328, 0.7154, 0.4514, 0.8239, 0.3587],
[0.5970, 0.1604, 0.0033, 0.8885, 0.3629],
[0.2186, 0.9431, 0.2264, 0.3357, 0.3118]],
[[0.8402, 0.7487, 0.8137, 0.0692, 0.9861],
[0.1275, 0.5480, 0.3803, 0.3801, 0.6754],
[0.7389, 0.3532, 0.5560, 0.4056, 0.2368],
[0.1113, 0.3072, 0.6570, 0.1285, 0.6331]],
[[0.0912, 0.3514, 0.2731, 0.9596, 0.1936],
[0.3107, 0.4428, 0.9672, 0.0778, 0.6484],
[0.3629, 0.7911, 0.9783, 0.7051, 0.5235],
[0.1730, 0.8745, 0.1580, 0.3193, 0.8202]]],
[[[0.4327, 0.4810, 0.0056, 0.8400, 0.3263],
[0.1467, 0.8376, 0.0766, 0.5909, 0.4188],
[0.3555, 0.7011, 0.2004, 0.2605, 0.5205],
[0.6036, 0.8388, 0.0610, 0.1489, 0.9452]],
[[0.5051, 0.0161, 0.3363, 0.1939, 0.9949],
[0.7931, 0.2976, 0.6276, 0.3221, 0.1810],
[0.7623, 0.5226, 0.7116, 0.4818, 0.5510],
[0.9556, 0.3049, 0.3479, 0.9650, 0.6561]],
[[0.8363, 0.0121, 0.5926, 0.7543, 0.4924],
[0.2830, 0.3250, 0.7983, 0.2548, 0.3496],
[0.2930, 0.8676, 0.3479, 0.1776, 0.7081],
[0.4618, 0.9499, 0.9068, 0.7570, 0.2282]]]])
torch.Size([2, 3, 4, 5])
采用层次的眼光来看这个输出结果,瞬间觉得清晰无比!最外层的中括号,是数组的象征,不用管,下一层,竖直方向有两个中括号,对应第一个2;下一层,竖直方向有3个中括号,对应第2个3;再下一层,4个中括号,对应4;最内层,5个数,对应最后一个5。
如果还保留着“行列”的概念,我觉得会被这一层一层的中括号绕晕了。
配合着size的输出来看,更加助于理解。
我觉得如果没有上面的概念,看squeeze、unsqueeze两个函数的用法,只能明白意思:一个降维一个升维,但具体细节(dim的取值)根本难以理解。
今天上午我盯着不同dim情况下unsqueeze的输出结果看了白天,完全是云里雾里!!到底增加的长度为1的维度加在哪里了?
这里其实有两个问题,一是不知道dim到底怎么起作用;二是不知道怎么看Tensor的维度,不然肯定是可以知道多出来的1个维度加在哪里的。
第二个问题很简单,就是用上面的size函数,直接可以输出Tensor的维度信息。
第一个问题,我受到了该回答的启发。下面详细介绍两个函数。
作用:消去一个或多个维度。只有在指定的维度长度=1时才生效。
用法:
a.squeeze(0)
:消去第1个维度(这里的索引是从0开始的,和数组索引一样)a.squeeze(0,2)
:消去第1、3个维度a.squeeze()
:消去所有长度=1的维度再次声明:如果指定的维度长度不为1,不报错但是也不会有任何作用。
a = torch.rand(1, 2, 1, 4, 5)
b = a.squeeze(0)
c = a.squeeze(1)
d = a.squeeze()
print(a.size())
print(b.size())
print(c.size())
print(d.size())
-------------------
torch.Size([1, 2, 1, 4, 5]) # 原数组维度
torch.Size([2, 1, 4, 5]) # 第一个维度长度=1,被消去
torch.Size([1, 2, 1, 4, 5]) # 第2个维度长度=2,不起作用
torch.Size([2, 4, 5]) # 第1、3两个长度=1的维度都被消去
具体的数组我就不输出了,我觉得还不如直接看size来得清楚。我就不多解释了,结合上面的用法说明,相信大家多能看明白。
注意:对于一个n维矩阵,dim取值的范围是-n~n-1
。n-1
的由来大家很清楚,就是因为索引从0开始,-n
的由来是因为索引也可以采用倒序的方式,对于上面的a数组,a.squeeze(-1)
表示消去长度=5的维度(当然不会起作用),a.squeeze(-5)
表示消去最左边长度=1的维度。这种用法可能会用于数组的长度太长,需要倒序索引的情况,但我相信这种情况可能比较少。
作用:增加一个长度=1的维度
用法:a.unsqueeze(t)
,在第t+1个空位(两个维度之间有一个空位,第一个维度前面和最后一个维度后面各有一个空位)增加一个维度。
注意:一次只能加一个,也就是说只能有一个输入,而不能a.unsqueeze(m,n)
。
a = torch.rand(2, 2, 2, 2, 2)
b = a.unsqueeze(0)
c = a.unsqueeze(1)
d = a.unsqueeze(2)
----------------
torch.Size([2, 2, 2, 2, 2])
torch.Size([1, 2, 2, 2, 2, 2])
torch.Size([2, 1, 2, 2, 2, 2])
torch.Size([2, 2, 1, 2, 2, 2])
这里的dim取值范围为-(n+1)~n
,正负范围都比squeeze要宽一个数,很好理解:5个人站一排,中间有6个空位。到了这里大家应该更加理解我前面的”第t+1个空位“是什么意思了
刚接触unsqueeze,我看着输入不用的dim值得到的输出结果陷入沉思,试图用肉眼解析出dim到底是控制了啥?维度到底加在哪里了?甚至发现dim的取值范围是可以变化的!更加有些崩溃。
我认为最重要的理解上的突破就是从嵌套关系、嵌套层次的角度去理解Tensor高维数组,摒弃了“行列”的概念。我称之为理解了“维度”的概念。
理解了“维度”的概念后,回过头看usqueeze和unsqueeze,在指定位置加减维度而已,根本不用把数组输出来,直接看size,之前的困惑一扫而空。