PyTorch 中两大核心转置函数 transpose() 和 permute()

关心差别的可以直接看【3.不同点】

前言

pytorch中转置用的函数就只有这两个

  1. transpose()
  2. permute()

这两个函数都是交换维度的操作。有一些细微的区别

1. 官方文档

transpose()
torch.transpose(input, dim0, dim1, out=None) → Tensor

函数返回输入矩阵input的转置。交换维度dim0dim1

参数:

  • input (Tensor) – 输入张量,必填
  • dim0 (int) – 转置的第一维,默认0,可选
  • dim1 (int) – 转置的第二维,默认1,可选
permute()
permute(dims) → Tensor

将tensor的维度换位。

参数:

  • dims (int…*)-换位顺序,必填

2. 相同点

  1. 都是返回转置后矩阵。
  2. 都可以操作高纬矩阵,permute在高维的功能性更强。

3.不同点

先定义我们后面用的数据如下

# 创造二维数据x,dim=0时候2,dim=1时候3
x = torch.randn(2,3)       'x.shape  →  [2,3]'
# 创造三维数据y,dim=0时候2,dim=1时候3,dim=2时候4
y = torch.randn(2,3,4)   'y.shape  →  [2,3,4]'
  1. 合法性不同

torch.transpose(x)合法, x.transpose()合法。
tensor.permute(x)合法,x.permute()合法。

参考第二点的举例

  1. 操作dim不同:

transpose()只能一次操作两个维度;permute()可以一次操作多维数据,且必须传入所有维度数,因为permute()的参数是int*

举例

# 对于transpose
x.transpose(0,1)     'shape→[3,2] '  
x.transpose(1,0)     'shape→[3,2] '  
y.transpose(0,1)     'shape→[3,2,4]' 
y.transpose(0,2,1)  'error,操作不了多维'

# 对于permute()
x.permute(0,1)     'shape→[2,3]'
x.permute(1,0)     'shape→[3,2], 注意返回的shape不同于x.transpose(1,0) '
y.permute(0,1)     "error 没有传入所有维度数"
y.permute(1,0,2)  'shape→[3,2,4]'
  1. transpose()中的dim没有数的大小区分;permute()中的dim有数的大小区分

举例,注意后面的shape

# 对于transpose,不区分dim大小
x1 = x.transpose(0,1)   'shape→[3,2] '  
x2 = x.transpose(1,0)   '也变换了,shape→[3,2] '  
print(torch.equal(x1,x2))
' True ,value和shape都一样'

# 对于permute()
x1 = x.permute(0,1)     '不同transpose,shape→[2,3] '  
x2 = x.permute(1,0)     'shape→[3,2] '  
print(torch.equal(x1,x2))
'False,和transpose不同'

y1 = y.permute(0,1,2)     '保持不变,shape→[2,3,4] '  
y2 = y.permute(1,0,2)     'shape→[3,2,4] '  
y3 = y.permute(1,2,0)     'shape→[3,4,2] '  

4.总结

最重要的区别应该是上面的第三点了。

另外,简单的数据用transpose()就可以了,但是个人觉得不够直观,指向性弱了点;复杂维度的可以用permute(),对于维度的改变,一般更加精准。

你可能感兴趣的:(pytorch,python)