最近被pytorch的几种Tensor维度转换方式搞得头大,故钻研了一下,将钻研历程和结果简述如下
注意:torch.__version__ == '1.2.0’
两者作用相似,都是用于交换不同维度的内容。但其中torch.transpose()
是交换指定的两个维度的内容,permute()
则可以一次性交换多个维度。具体情况如code所示:
transpose(): 两个维度的交换
>>> a = torch.Tensor([[[1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15]],
[[-1,-2,-3,-4,-5], [-6,-7,-8,-9,-10], [-11,-12,-13,-14,-15]]])
>>> a.shape
torch.Size([2, 3, 5])
>>> print(a)
tensor([[[ 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10.],
[ 11., 12., 13., 14., 15.]],
[[ -1., -2., -3., -4., -5.],
[ -6., -7., -8., -9., -10.],
[-11., -12., -13., -14., -15.]]])
>>> b = a.transpose(1,2) # 使用transpose,将维度1和2进行交换。这个很好理解。转换后tensor与其shape如下
>>> print(b, b.shape)
(tensor([[[ 1., 6., 11.],
[ 2., 7., 12.],
[ 3., 8., 13.],
[ 4., 9., 14.],
[ 5., 10., 15.]],
[[ -1., -6., -11.],
[ -2., -7., -12.],
[ -3., -8., -13.],
[ -4., -9., -14.],
[ -5., -10., -15.]]]),
torch.Size([2, 5, 3])))
permute():一次性做任意维度的交换
>>> c = a.permute(2, 0, 1)
>>> print(c, c.shape) # 此举将原维度0,1,2的次序变为2,1,0,所以shape也发生了相应的变化。
(tensor([[[ 1., 6., 11.],
[ -1., -6., -11.]],
[[ 2., 7., 12.],
[ -2., -7., -12.]],
[[ 3., 8., 13.],
[ -3., -8., -13.]],
[[ 4., 9., 14.],
[ -4., -9., -14.]],
[[ 5., 10., 15.],
[ -5., -10., -15.]]]),
torch.Size([5, 2, 3]))
transpose()和permute()之间的转化:
>>> b = a.permute(2,0,1)
>>> c = a.transpose(1,2).transpose(0,1)
>>> print(b == c, b.shape)
(tensor([[[True, True, True],
[True, True, True]],
[[True, True, True],
[True, True, True]],
[[True, True, True],
[True, True, True]],
[[True, True, True],
[True, True, True]],
[[True, True, True],
[True, True, True]]]),
torch.Size([5, 2, 3]))
如代码所示,先将Tensor a的1,2维度进行交换,再将得到的Tensor的0,1维度再交换,得到的结果和permute是一样的。
view()
是个在pytorch中很常见的函数。该函数也起到转换Tensor维度的作用,但它转换的方式和transpose()/permute()截然不同。如果说tranpose()
是按照Tensor的原有维度忠实地进行交换,那么view()
就直接而且简单的多——首先,view()函数会将Tensor所有维度拉平成一维,然后再根据传入的的维度信息重构出一个Tensor。code如下:
# 还是上面的Tensor a
>>> print(a.shape)
torch.Size([2, 3, 5])
>>> print(a.view(2,5,3))
tensor([[[ 1., 2., 3.],
[ 4., 5., 6.],
[ 7., 8., 9.],
[ 10., 11., 12.],
[ 13., 14., 15.]],
[[ -1., -2., -3.],
[ -4., -5., -6.],
[ -7., -8., -9.],
[-10., -11., -12.],
[-13., -14., -15.]]])
>>> c = a.transpose(1,2)
>>> print(c, c.shape)
(tensor([[[ 1., 6., 11.],
[ 2., 7., 12.],
[ 3., 8., 13.],
[ 4., 9., 14.],
[ 5., 10., 15.]],
[[ -1., -6., -11.],
[ -2., -7., -12.],
[ -3., -8., -13.],
[ -4., -9., -14.],
[ -5., -10., -15.]]]),
torch.Size([2, 5, 3]))
如代码所示。即使view()
和transpose()
最终得到的Tensor的shape是一样的,但二者内容并不相同。view函数只是按照给定的(2,5,3)的Tensor维度,将元素按顺序一个个填进去;而transpose函数,则的确是在进行第一个第二维度的转置。
此外,有些情况下转置(transpose)后的Tensor是无法被view的,原因在于,转置后的Tensor不是“连续”(non-contiguous)的。关于contiguous array的问题numpy里也一样,在这里有个很棒的解释