关于 Numpy 以及 Pytorch 的数组shape的一点总结

文章目录

      • 1.数组中数据存储的结构
      • 2.数组的坐标问题
      • 3.对于Pytorch 的shape相关问题
      • 4. Pytorch 中几个常见的有关维度的函数
        • 4.1 squeeze() 和 unsqueeze()
        • 4.2 permute() 函数
        • 4.3 transpose()函数

不知道大家有没有类似的问题,处理数据的时候很多时候会被各种数组的 shape 的变化搞晕,但是这方面的资料又不太好找,这里记录一点我遇到的这方面的知识点。

1.数组中数据存储的结构

首先从数据的排列来看,从一个简单的例子来看:

a = np.linspace(1, 24, 24).reshape(2, 3, 4)
print(a)

# 输出结果:
[[[ 1.  2.  3.  4.]
  [ 5.  6.  7.  8.]
  [ 9. 10. 11. 12.]]

 [[13. 14. 15. 16.]
  [17. 18. 19. 20.]
  [21. 22. 23. 24.]]]

首先 1 - 24 的24个数字,排列成 shape 为 (2, 3, 4) 的数组,可以看到这种数据排列的方式显然是按照从最低维(这里是一个三维数组,0,1,2)2开始排列数据,这样至少确定了在进行reshape的时候数据是如何填充的,我的猜想是reshape可以看作是两步:

  • 1.首先将数组整个拉平(既然填充的时候是现在高维度上进行填充,那么拉平的过程就是反向的了)
  • 然后按照新的 shape 按照上面的方式进行填充

2.数组的坐标问题

举个例子来看非常明显:

a = np.linspace(1, 24, 24).reshape(2, 3, 4)

print(a[1][0][0])
print(a[0][1][0])
print(a[0][0][1])

# 可以看到输出分别为:
13.0
5.0
2.0

也就是最低维是行(向下增加),第二维是列,第三维是深。

3.对于Pytorch 的shape相关问题

Numpy 中的数组 array 也就是对应 Pytorch 的 tensor ,那么上面的方式应该在 pytorch 中也是适用的,测试:

b = torch.from_numpy(a)
print(b)
print(b[1][0][0])
print(b[0][1][0])
print(b[0][0][1])

# 输出为:
tensor([[[ 1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12.]],

        [[13., 14., 15., 16.],
         [17., 18., 19., 20.],
         [21., 22., 23., 24.]]], dtype=torch.float64)
tensor(13., dtype=torch.float64)
tensor(5., dtype=torch.float64)
tensor(2., dtype=torch.float64)

测试一下 Pytorch 的 reshape 的函数,Pytorch 中经常使用 view()函数来进行改变数组的 size:

c = b.view(4, 3, 2)
print(c)

# 输出结果:
tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.]],

        [[ 7.,  8.],
         [ 9., 10.],
         [11., 12.]],

        [[13., 14.],
         [15., 16.],
         [17., 18.]],

        [[19., 20.],
         [21., 22.],
         [23., 24.]]], dtype=torch.float64)

确实与 Numpy 的 reshape 方式一致

4. Pytorch 中几个常见的有关维度的函数

4.1 squeeze() 和 unsqueeze()

这是一对操作相反的函数,分别用于 降维/升维 操作,例如:

a = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(a.size())
print(a)

b = a.unsqueeze(0)
print(b.size())
print(b)


# 输出结果为:
torch.Size([2, 3])
tensor([[1, 2, 3],
        [4, 5, 6]])
        
torch.Size([1, 2, 3])
tensor([[[1, 2, 3],
         [4, 5, 6]]])

两个函数都需要传入 dim 参数,指定在哪一个维度上进行 压缩 或者 升维,上面的例子中可以看到在第0维上增加了一个维度。当然这对数据本身并没有什么改变,但是 Pytorch 中对输入的数据格式都做了要求,往往要求是一个四元组,所以很多时候都要对原始数据进行一些改变,所以就用得到这个函数了。

与之对应:

c = b.squeeze(0)
print(c.size())
print(c)

# 输出:
torch.Size([2, 3])
tensor([[1, 2, 3],
        [4, 5, 6]])

不过因为是压缩,所以只有在维度为1的时候才生效。

4.2 permute() 函数

从函数名就可以看出这个函数用于重新排列,这与 reshape 之类的函数区别就在于这是是直接进行重新排列,例如:

对于 [[1, 2],
      [3, 4],
      [5, 6]]
      
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = a.permute(1, 0)

print(b)
# 输出:
tensor([[1, 3, 5],
        [2, 4, 6]])

可以看到这是相当于进行了一个转置,也就是维度上的重新排列,如果得到 (2, 3) 的数组,使用view()函数则是:

c = a.view(2, 3)
print(c)

# 输出:
tensor([[1, 2, 3],
        [4, 5, 6]])

可以很明显地看到区别。

对于二维的数据,这个函数就是简单地转置,但是对于三维数据,有点抽象,举个例子来说:

a = torch.arange(1, 25)
a = a.view(2, 3, 4)

print(a)
b = a.permute(2, 1, 0)

print(b)

# 输出结果:

tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]],

        [[13, 14, 15, 16],
         [17, 18, 19, 20],
         [21, 22, 23, 24]]])
tensor([[[ 1, 13],
         [ 5, 17],
         [ 9, 21]],

        [[ 2, 14],
         [ 6, 18],
         [10, 22]],

        [[ 3, 15],
         [ 7, 19],
         [11, 23]],

        [[ 4, 16],
         [ 8, 20],
         [12, 24]]])

这里相当于将所有的维度进行了相反的排列,看数据的排列就会发现这就像是矩阵的立体结构进行了旋转,但是内部的数据排列又要改变位置,总之数据要符合原来的排列顺序,这个也称它为高维矩阵的转置,也就是熟知的二维数据转置的推广。(这已经超越我语言的极限了…╮(╯_╰)╭),基本就是这个意思。

更详细的可以参考:Pytorch之permute函数

4.3 transpose()函数

transpose()函数是经常使用的转置函数,用于。。。转置 ( ̄_ ̄|||)
但是与上面提到的 permute 函数不同的是, transpose 函数只能在两个维度上进行转置,所以如果某次操作涉及三个以及更多维度的转置的话,使用 permute 函数更方便,当然,连续使用多次 transpose 函数也可以达到同样的目的,例如:

a = torch.randn(2, 3, 4)

# transpose() 函数可以进行两个维度的转置 所以连续使用多次transpose()函数等同于对应的permute()函数
b = a.permute(1, 2, 0)
c = a.transpose(0, 1)
c = c.transpose(1, 2)

print(b.equal(c))

print(b.size())
print(c.size())


输出:
True
torch.Size([3, 4, 2])
torch.Size([3, 4, 2])

可以看到,两种函数实现了一样的结果。

你可能感兴趣的:(Pytorch)