transpose的含义及用法(numpy和pytorch)

 

numpy中的transpose

Permute the dimensions of an array.

Parameters
----------
a : array_like
    Input array.
axes : list of ints, optional
    By default, reverse the dimensions, otherwise permute the axes
    according to the values given.

Returns
-------
p : ndarray
    `a` with its axes permuted.  A view is returned whenever
    possible.

transpose可以理解为矩阵转置概念在高维数组上的扩展。

第一个参数为需要进行transpose操作的ndarry,axes为调整后的坐标轴顺序。

举个例子,

import numpy as np

x=np.random.rand(3,2,5,6)

x有4个dim,分别为3,2,5,6,它们的编号为0,1,2,3.

np.transpose(x,(0,2,1,3)).shape

#(3, 5, 2, 6)

得到的结果为(3,5,2,6)

新的轴号 该轴上的长度 原来的轴号
0 3 0
1 5 2
2 2 1
3 6 3

 

表格中的第三列就是我们传进去的参数axes。

一般numpy都会给方法封装两种调用接口

  1. 顶层调用
    np.transpose(x,(0,2,1,3))

     

  2. 成员函数调用
    x.transpose(0,2,1,3)

     

这两种用法是等价的。

 

pytorch中的transpose

from torch import Tensor

z=Tensor(x)
#z.shape is (3, 2, 5, 6)

z.transpose(2,3).shape

#torch.Size([3, 2, 6, 5])

该transpose函数只接受两个参数,指示了两个轴的编号,它的作用是调换这两个轴,在例子中就是调换最后两个轴。

顶层调用不再赘述。

 

你可能感兴趣的:(学习pytorch,numpy,python)