【Python学习】transpose函数

shape:(batch_size * x * y )
有batch_size个二维矩阵(x * y)相当于(z * x * y)

1. 多维数组的索引

import numpy as np
# 创建
x = np.arange(12).reshape((2,2,3))
print(x)

# 得到三维数组
[[[ 0  1  2]
  [ 3  4  5]]

 [[ 6  7  8]
  [ 9 10 11]]]
# 相当于 batch_size*2*3

怎么读索引
上面的三维数组相当于下图
【Python学习】transpose函数_第1张图片

对于第一个括号,相当于说这个三维数组有两个2维数组
对于里层的第二个括号,就是最常见的二维矩阵
对于‘7’,x,y为二维数组中的1,0(索引从0开始,横向是x轴,纵向是y轴,与shape中的x,y有所不同),z是代表‘z’在第2个二维矩阵中。所以‘7’的索引是(1,0,1)(x, y, z)

2. numpy的transpose函数
按轴交换
transpose函数中的两个参数是要互换的轴

import numpy as np
# 创建
x = np.arange(12).reshape((2,2,3))
print(x)
#输出
[[[ 0  1  2]
  [ 3  4  5]]

 [[ 6  7  8]
  [ 9 10 11]]]
# transpose(z,x,y)-> (x,y,z)
y = np.transpose((x),(1,2,0))
print(y)
#输出
[[[ 0  6]
  [ 1  7]
  [ 2  8]]

 [[ 3  9]
  [ 4 10]
  [ 5 11]]]

3. Pytorch的transpose函数

transpose只能对两个维度进行转换

import numpy as np
import torch
# 创建
x = np.arange(12).reshape((2,2,3))
x = torch.Tensor(x)
print(x)
# 输出
tensor([[[ 0.,  1.,  2.],
         [ 3.,  4.,  5.]],

        [[ 6.,  7.,  8.],
         [ 9., 10., 11.]]])
# transpose:转换第一维度和第二维度,即二维矩阵的x和y
a = x.transpose(1,2)
print(a)
#输出
tensor([[[ 0.,  3.],
         [ 1.,  4.],
         [ 2.,  5.]],

        [[ 6.,  9.],
         [ 7., 10.],
         [ 8., 11.]]])
# transpose:
b = x.transpose(0,2)
print(b)
# 输出
tensor([[[ 0.,  6.],
         [ 3.,  9.]],

        [[ 1.,  7.],
         [ 4., 10.]],

        [[ 2.,  8.],
         [ 5., 11.]]])

你可能感兴趣的:(Python编程,python,pytorch)