torch.transpose()函数torch.permute()函数解读

目录

1 官网文档

transpose()

permute()

2 使用及对比

使用: 

操作的dim不同:

维度数顺序影响:

合法性不同:

相同点:

再使用view()和reshape()时


在pytorch中转置用的函数就只有这两个

1.transpose()

2.permute()

1 官网文档

transpose()

http://torch.transpose — PyTorch 1.10.0 documentation

torch.transpose(inputdim0dim1) → 再使用view()和reshape()时 

 函数返回输入矩阵input的转置,交换维度dim0和dim1

参数:

  1. input(Tensor)-输入张量,必填
  2. dim0(int)-转置的第一维,默认0,可选
  3. dim1(int)-转置的第二维,默认1,可选

注意只能有两个相关的交换的位置参数。

permute()

http://torch.permute — PyTorch 1.10.0 documentation

torch.permute(inputdims) → Tensor

参数:

        dims(int...*)-换位顺序,必填

2 使用及对比

使用: 

# 创造二维数据x   dim0 = 2, dim1 = 3
x = torch.randn(2, 3)   # x.shape = (2, 3)
# 创造三维数据y   dim0 = 2, dim1 = 3, dim2 = 4
y = torch.randn(2, 3, 4)   # y.shape = (2, 3, 4)

 对于transpose

x.transpose(0, 1)   # x.shape = (3, 2)
x.transpose(1, 0)   # x.shape = (3, 2)
y.transpose(0, 1)   # y.shape = (3, 2, 4)
y.transpose(0, 2, 1)   # error 操作不了多维

对于permute()

x.permute(0, 1)   # x.shape = (2, 3)
x.permute(1, 0)   # x.shape = (3, 2)
y.permute(0, 1)   # error 没有传入所有维度数
y.permute(1, 0, 2)   # y.shape = (3, 2, 4)

操作的dim不同:

  1. transpose()一次只能操作两个维度;
  2. permute()一次可以操作多个维度,但是必须传入所有维度,因为permute()的参数是int*。 

维度数顺序影响:

  1. transpose()中的dim没有维度数的大小之分(0,1)和(1,0)是相同结果;
  2. permute()中的dim有维度数的大小之分(0,1)和(1,0)是不一样的操作。

合法性不同:

  1. torch.transpose(x)合法,x.transpose()合法
  2. tensor.permute(x)不合法,x.permute()合法

相同点:

  1. 都是返回转置后的矩阵
  2. 都可以操作高位矩阵,permute在高维的功能性更强。 

再使用view()和reshape()时

view()函数改变转置后的数据结构,会导致报错。这是因为tensor转置后数据的内存地址不连续,也就是tensor.is_contiguous()==False

x.torch.rand(3, 4)
x = x.transpose(0, 1)
print(x.is_contiguous())   # 是否连续 --> 'False'
#  发现
x.view(3, 4)   # view会报错
# 但是这样是可以的
x = x.contiguous()
x.view(3, 4)

虽然在torch里面,view函数相当于numpy的reshape,但是这时候reshape()可以改变该tensor的结构,但是view()不可以。

x = torch.rand(3, 4)
x = x.permute(1, 0)   # 等价于x = x.transpose(0, 1)
x.reshape(3, 4)   # reshape不报错

说明x.reshape(3, 4)这个操作等于x = x.contiguous().view(),尽管如此,但是我们还是不推荐使用reshape,除非为了获取完全不同但是数据相同的克隆体。 

 

 

你可能感兴趣的:(深度学习,pytorch,深度学习,人工智能)