Pytorch张量转置维度该怎么写?

今天在看代码时发现torch里的转置参数写成

transpose(1, 0)

transpose(0, 1)

结果是一样的

即:

import torch
x = torch.randn(3, 2)
print(x.transpose(1, 0).shape)
print(x.transpose(0, 1).shape)
# 运行结果
# torch.Size([2, 3])
# torch.Size([2, 3])

当时我就不理解,当时的思维定势是觉得print(x.transpose(0, 1).shape)结果应该是没对张量进行操作时的形状,后来经过实验才发现,这个参数只能传入两个,就是这两个维度互换位置

接着,我又进行了别的实验,比如三维张量

import torch
x = torch.randn(3, 2, 4)
print(x.transpose(2, 0).shape)
print(x.transpose(0, 2).shape)
# 运行结果
# torch.Size([4, 2, 3])
# torch.Size([4, 2, 3])

我又看了numpy里的transpose函数,其用法和上面的transpose有点区别,对比着写了测试代码如下:

import numpy as np
import torch
x = torch.randn(3, 2, 4)
print(x.transpose(2, 1).shape)
print(x.transpose(1, 2).shape)
print(np.transpose(x, [2, 0, 1]).shape)
print(np.transpose(x, [0, 2, 1]).shape)
print(np.transpose(x, [0, 1, 2]).shape)
# 运行结果
# torch.Size([3, 4, 2])
# torch.Size([3, 4, 2])
# torch.Size([4, 3, 2])
# torch.Size([3, 4, 2])
# torch.Size([3, 2, 4])

以上仅为个人学习记录,若有理解不当之处,欢迎批评指正!!!

你可能感兴趣的:(pytorch)