【PyTorch】transpose() 和 permute() 函数:交换张量维度

在 PyTorch 中,transposepermute 都是用于调整张量维度的函数。它们在很多深度学习任务中非常有用,尤其是在处理张量维度和进行矩阵操作时。

1. transpose 函数

transpose 函数用来交换张量的两个维度。它接受两个参数,即需要交换的两个维度的索引。这个操作不会改变张量的数据本身,只是改变了张量的视图。

语法
torch.transpose(input, dim0, dim1)
  • input:输入的张量。
  • dim0:要交换的第一个维度的索引。
  • dim1:要交换的第二个维度的索引。
返回值

返回一个新的张量,其中 dim0dim1 的维度被交换了。

示例
import torch

# 创建一个 2x3 的张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 交换第0维和第1维
y = torch.transpose(x, 0, 1)

print("Original Tensor:")
print(x)

print("\nTransposed Tensor:")
print(y)

输出

Original Tensor:
tensor([[1, 2, 3],
        [4, 5, 6]])

Transposed Tensor:
tensor([[1, 4],
        [2, 5],
        [3, 6]])

在这个例子中,我们交换了张量 x 的第 0 维(行)和第 1 维(列),所以得到的张量 y 是一个 3x2 的张量。

2. permute 函数

permute 函数可以重新排列张量的所有维度。与 transpose 仅能交换两个维度不同,permute 允许你指定任意的维度顺序。

语法
torch.permute(input, dims)
  • input:输入的张量。
  • dims:一个包含维度索引的元组,表示新的维度顺序。
返回值

返回一个新的张量,维度顺序根据 dims 进行调整。

示例
import torch

# 创建一个 2x3x4 的张量
x = torch.randn(2, 3, 4)

# 调整维度顺序
y = x.permute(2, 0, 1)

print("Original Tensor Shape:", x.shape)
print("Permuted Tensor Shape:", y.shape)

输出

Original Tensor Shape: torch.Size([2, 3, 4])
Permuted Tensor Shape: torch.Size([4, 2, 3])

在这个例子中,原始张量的形状是 (2, 3, 4),我们通过 permute 调整维度顺序为 (4, 2, 3)

详细说明
  • transpose 只交换两个维度。
  • permute 可以自由地重新排列所有维度。例如,x.permute(2, 0, 1) 将维度 201 进行了交换。

总结

  • transpose(dim0, dim1) 用于交换张量的两个维度,适用于二维及以上的张量。
  • permute(dims) 可以重新排列张量的所有维度,适用于任意维度的张量。

你可能感兴趣的:(PyTorch基础,transpose,permute,调整张量维度,pytorch,python)