在pytorch
中转置用的函数就只有这两个
transpose()
permute()
这两个函数都是交换维度的操作。有一些细微的区别
transpose()
torch.transpose(input, dim0, dim1, out=None) → Tensor
函数返回输入矩阵input
的转置。交换维度dim0
和dim1
参数:
permute()
permute(dims) → Tensor
将tensor的维度换位。
参数:
permute
在高维的功能性更强。先定义我们后面用的数据如下
# 创造二位数据x,dim=0时候2,dim=1时候3
x = torch.randn(2,3) 'x.shape → [2,3]'
# 创造三维数据y,dim=0时候2,dim=1时候3,dim=2时候4
y = torch.randn(2,3,4) 'y.shape → [2,3,4]'
1、合法性不同
torch.transpose(x)
合法, x.transpose()
合法。tensor.permute(x)
不合法,x.permute()
合法。
2、操作dim
不同:(和numpy的swapaxes、transpose()类似)
transpose()
只能一次操作两个维度;permute()
可以一次操作多维数据,且必须传入所有维度数,因为permute()
的参数是int*
。
举例
# 对于transpose
x.transpose(0,1) 'shape→[3,2] '
x.transpose(1,0) 'shape→[3,2] '
y.transpose(0,1) 'shape→[3,2,4]'
y.transpose(0,2,1) '报错,操作不了多维'
# 对于permute()
x.permute(0,1) 'shape→[2,3]'
x.permute(1,0) 'shape→[3,2]'
y.permute(0,1) "报错,number of dims don't match in permute"
y.permute(1,0,2) 'shape→[3,2,4],必须写满所有的维度才不报错'
1、transpose()
中的dim
没有数的大小区分;permute()
中的dim
有数的大小区分
(原因见blog《numpy中两大核心转置函数 transpose() 和 swapaxes()(类似pytorch的交换维度、转置)》)
举例,注意后面的shape
:
# 对于transpose,不区分dim大小
x1 = x.transpose(0,1) 'shape变换了,shape→[3,2] '
x2 = x.transpose(1,0) '也变换了,shape→[3,2] '
print(torch.equal(x1,x2))
' True'
# 对于permute()
x1 = x.permute(0,1) '不同transpose,shape→[2,3] '
x2 = x.permute(1,0) 'shape→[3,2] '
print(torch.equal(x1,x2))
'False'
y1 = y.permute(0,1,2) '保持不变,shape→[2,3,4] '
y2 = y.permute(1,0,2) 'shape→[3,2,4] '