torch.view()、transpose()和permute()的联系与区别

最近被pytorch的几种Tensor维度转换方式搞得头大,故钻研了一下,将钻研历程和结果简述如下

注意:torch.__version__ == '1.2.0’

torch.transpose()和torch.permute()

两者作用相似,都是用于交换不同维度的内容。但其中torch.transpose()是交换指定的两个维度的内容,permute()则可以一次性交换多个维度。具体情况如code所示:
transpose(): 两个维度的交换

 >>> a = torch.Tensor([[[1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15]], 
                  [[-1,-2,-3,-4,-5], [-6,-7,-8,-9,-10], [-11,-12,-13,-14,-15]]])
 >>> a.shape
 torch.Size([2, 3, 5])
 >>> print(a)
 tensor([[[  1.,   2.,   3.,   4.,   5.],
         [  6.,   7.,   8.,   9.,  10.],
         [ 11.,  12.,  13.,  14.,  15.]],

        [[ -1.,  -2.,  -3.,  -4.,  -5.],
         [ -6.,  -7.,  -8.,  -9., -10.],
         [-11., -12., -13., -14., -15.]]])
 >>> b = a.transpose(1,2)  # 使用transpose,将维度1和2进行交换。这个很好理解。转换后tensor与其shape如下
 >>> print(b, b.shape)
 (tensor([[[  1.,   6.,  11.],
         [  2.,   7.,  12.],
         [  3.,   8.,  13.],
         [  4.,   9.,  14.],
         [  5.,  10.,  15.]],

        [[ -1.,  -6., -11.],
         [ -2.,  -7., -12.],
         [ -3.,  -8., -13.],
         [ -4.,  -9., -14.],
         [ -5., -10., -15.]]]),
torch.Size([2, 5, 3])))

permute():一次性做任意维度的交换

 >>> c = a.permute(2, 0, 1)
 >>> print(c, c.shape)  # 此举将原维度0,1,2的次序变为2,1,0,所以shape也发生了相应的变化。
 (tensor([[[  1.,   6.,  11.],
          [ -1.,  -6., -11.]],
 
         [[  2.,   7.,  12.],
          [ -2.,  -7., -12.]],
 
         [[  3.,   8.,  13.],
          [ -3.,  -8., -13.]],
 
         [[  4.,   9.,  14.],
          [ -4.,  -9., -14.]],
 
         [[  5.,  10.,  15.],
          [ -5., -10., -15.]]]),
 torch.Size([5, 2, 3]))

transpose()和permute()之间的转化:

>>> b = a.permute(2,0,1)
>>> c = a.transpose(1,2).transpose(0,1)
>>> print(b == c, b.shape)
(tensor([[[True, True, True],
          [True, True, True]],
 
         [[True, True, True],
          [True, True, True]],
 
         [[True, True, True],
          [True, True, True]],
 
         [[True, True, True],
          [True, True, True]],
 
         [[True, True, True],
          [True, True, True]]]),
 torch.Size([5, 2, 3]))

如代码所示,先将Tensor a的1,2维度进行交换,再将得到的Tensor的0,1维度再交换,得到的结果和permute是一样的。


transpose()和view()

view()是个在pytorch中很常见的函数。该函数也起到转换Tensor维度的作用,但它转换的方式和transpose()/permute()截然不同。如果说tranpose()是按照Tensor的原有维度忠实地进行交换,那么view()就直接而且简单的多——首先,view()函数会将Tensor所有维度拉平成一维,然后再根据传入的的维度信息重构出一个Tensor。code如下:

# 还是上面的Tensor a
 >>> print(a.shape)
 torch.Size([2, 3, 5])
 >>> print(a.view(2,5,3))
 tensor([[[  1.,   2.,   3.],
         [  4.,   5.,   6.],
         [  7.,   8.,   9.],
         [ 10.,  11.,  12.],
         [ 13.,  14.,  15.]],

        [[ -1.,  -2.,  -3.],
         [ -4.,  -5.,  -6.],
         [ -7.,  -8.,  -9.],
         [-10., -11., -12.],
         [-13., -14., -15.]]])
  >>> c = a.transpose(1,2)
 >>> print(c, c.shape)
(tensor([[[  1.,   6.,  11.],
          [  2.,   7.,  12.],
          [  3.,   8.,  13.],
          [  4.,   9.,  14.],
          [  5.,  10.,  15.]],
 
         [[ -1.,  -6., -11.],
          [ -2.,  -7., -12.],
          [ -3.,  -8., -13.],
          [ -4.,  -9., -14.],
          [ -5., -10., -15.]]]),
 torch.Size([2, 5, 3]))

如代码所示。即使view()transpose()最终得到的Tensor的shape是一样的,但二者内容并不相同。view函数只是按照给定的(2,5,3)的Tensor维度,将元素按顺序一个个填进去;而transpose函数,则的确是在进行第一个第二维度的转置


此外,有些情况下转置(transpose)后的Tensor是无法被view的,原因在于,转置后的Tensor不是“连续”(non-contiguous)的。关于contiguous array的问题numpy里也一样,在这里有个很棒的解释

你可能感兴趣的:(#,pytorch,数学,python,线性代数)