pytorch tranpose与permute函数

transposepermute都是转置函数,可以交换Tensor的维度。

1. transpose

torch.transpose(input, dim0, dim1, out=None) → \rightarrow Tensor

transpose函数用于交换input的维度dim0dim1,只能交换两个维度,且dim0和dim1的参数位置没有顺序而言。

例子:

a = torch.arange(6).reshape((2, 3))
a, a.transpose(1, 0) # shape从(2, 3)变成(3, 2)

Out:tensor([[0, 1, 2],
            [3, 4, 5]])
    tensor([[0, 3],
            [1, 4],
            [2, 5]])

a.transpose(1, 0)a.transpose(0, 1)相同,都是将第0个维度和第1个维度交换。

a.transpose(1, 0) == a.transpose(0, 1)

Out: tensor([[True, True],
            [True, True],
            [True, True]])

作用于高维:

b = torch.arange(24).reshape((2, 3, 4))
b.transpose(1, 2)		# shape从(2, 3, 4)变成(2, 4, 3)
b.transpose(0, 1, 2)	# 报错,只能输入两个维度进行交换

2. permute

torch.permute(input, dims) → \rightarrow Tensor

permute函数相当于把input的各个维度进行了重排列,可以一次性交换多个维度,参数dims的长度必须与input的维度相同。

例子:

print(torch.permute(a, (1, 0)))			# shape从(2, 3) 变成 (3, 2)
print(torch.permute(a, (0, 1)) == a)	# 不变

Out:tensor([[0, 3],
            [1, 4],
            [2, 5]])
    tensor([[True, True, True],
            [True, True, True]])

作用于高维:

b.permute(1, 0, 2)	# shape从(2, 3, 4)变成(3, 2, 4)
b.permute(2, 0, 1)	# shape从(2, 3, 4)变成(4, 2, 3)
b.permute(1, 0)		# 报错,必须要输入三个维度

由于permute一次可以操作多个维度,因此在高维的功能性比transpose更强,不过permute能做到的transpose也能做到,只不过transpose可能要多调用几次,transpose能做到的permute也都能做到。

你可能感兴趣的:(python,pytorch,深度学习,python)