transpose在numpy和torch中的不同

numpy和torch中transpose的功能不同

numpy.ndarray.transpose的官方文档
torch.transpose的官方文档

numpy.transpose需指定一个新的axis的顺序

import numpy as np
array_ = np.arange(24).reshape(1,2,3,4)
print(array_)
# [[[[ 0  1  2  3]
#    [ 4  5  6  7]
#    [ 8  9 10 11]]
# 
#   [[12 13 14 15]
#    [16 17 18 19]
#    [20 21 22 23]]]]

array_t = array_.transpose([0,3,2,1])
print(array_t)
# [[[[ 0 12]
#    [ 4 16]
#    [ 8 20]]
# 
#   [[ 1 13]
#    [ 5 17]
#    [ 9 21]]
# 
#   [[ 2 14]
#    [ 6 18]
#    [10 22]]
# 
#   [[ 3 15]
#    [ 7 19]
#    [11 23]]]]

array_tt = array_.transpose([0,3,2,1])
print(array_tt)
# [[[[ 0  1  2  3]
#    [ 4  5  6  7]
#    [ 8  9 10 11]]
# 
#   [[12 13 14 15]
#    [16 17 18 19]
#    [20 21 22 23]]]]

torch.transpose仅交换两个特定的axis

import torch
import numpy as np
tensor_ = torch.from_numpy(np.arange(24).reshape(1,2,3,4))
print(tensor_)
# tensor([[[[ 0,  1,  2,  3],
#           [ 4,  5,  6,  7],
#           [ 8,  9, 10, 11]],
# 
#          [[12, 13, 14, 15],
#           [16, 17, 18, 19],
#           [20, 21, 22, 23]]]])

tensor_t = tensor_.transpose(1, 3)
print(tensor_t)
# tensor([[[[ 0, 12],
#           [ 4, 16],
#           [ 8, 20]],
# 
#          [[ 1, 13],
#           [ 5, 17],
#           [ 9, 21]],
# 
#          [[ 2, 14],
#           [ 6, 18],
#           [10, 22]],
# 
#          [[ 3, 15],
#           [ 7, 19],
#           [11, 23]]]])

tensor_tt = tensor_t.transpose(1, 3)
print(tensor_tt)
# tensor([[[[ 0,  1,  2,  3],
#           [ 4,  5,  6,  7],
#           [ 8,  9, 10, 11]],
# 
#          [[12, 13, 14, 15],
#           [16, 17, 18, 19],
#           [20, 21, 22, 23]]]])

在numpy中使用torch的语法会导致ValueError

import numpy as np
array_ = np.arange(24).reshape(1,2,3,4)
array_t = array_.transpose(1,3)
# Traceback (most recent call last):
#   File "", line 1, in 
# ValueError: axes don't match array

你可能感兴趣的:(遇过的坑,pytorch,深度学习,python,numpy)