torch中transpose和permute转置问题

在pytorch中转置用的函数就只有这两个
transpose():操作不了多维
permute():可以操作多维

百度搜索的教程中的理论太多太多,不如直接代码测试容易让人理解,话不多说,直接代码尝试:

t = torch.randn(2, 4, 5)   # 首先创建2个正态分布的4*5的矩阵
print(t)
tensor([[[ 1.0224,  0.5716, -1.2172, -0.0534, -1.0312],
         [ 0.0622, -0.0260,  2.6485, -0.9420, -0.1987],
         [-0.6560,  0.0956, -2.2045, -0.6329,  2.3294],
         [-0.0351,  1.0526, -0.1086, -1.1315, -0.2870]],

        [[ 0.0081, -0.5649, -0.4293, -0.4485, -1.5479],
         [-0.0086, -1.2145,  2.0289,  0.5889, -0.2644],
         [ 0.1313,  0.2485, -1.1323, -0.8699,  0.2849],
         [ 0.3727, -0.0079,  0.3927,  1.4980,  0.5328]]])

# randn(2,4,5)中三个数的索引分别为 0,1,2
t1=t.transpose(1,0) #此时transpose中的参数1,0表示交换t中的索引位置,t1也既是(4,2,5)表示4个2*5的矩阵
# t1=t.transpose(0,1)也是和上述一样的意思,交换0,1的位置,结果都是下图结果
print(t1)
tensor([[[ 1.0224,  0.5716, -1.2172, -0.0534, -1.0312],
         [ 0.0081, -0.5649, -0.4293, -0.4485, -1.5479]],

        [[ 0.0622, -0.0260,  2.6485, -0.9420, -0.1987],
         [-0.0086, -1.2145,  2.0289,  0.5889, -0.2644]],

        [[-0.6560,  0.0956, -2.2045, -0.6329,  2.3294],
         [ 0.1313,  0.2485, -1.1323, -0.8699,  0.2849]],

        [[-0.0351,  1.0526, -0.1086, -1.1315, -0.2870],
         [ 0.3727, -0.0079,  0.3927,  1.4980,  0.5328]]])

同理,permute道理相通,只是可以操作多维而已,但必须传入所有维度数。

t = torch.randn(2, 4, 5)   # 首先创建2个正态分布的4*5的矩阵
t
tensor([[[ 0.2464, -1.5848,  0.4432, -0.8214, -1.3044],
         [-0.0355, -0.4341,  0.3624, -1.4011,  0.0111],
         [ 1.3601,  0.1008, -1.4646,  0.2118,  0.1643],
         [ 1.9176, -0.0868,  0.8551,  0.4760, -1.5810]],

        [[ 0.4147, -1.2642,  1.1018,  0.4975, -0.3797],
         [-1.0450,  1.0998, -0.8400,  0.5221,  1.0553],
         [-0.7401,  1.4456,  0.9995, -0.6732, -0.5768],
         [ 1.0525,  0.5885,  1.3591, -0.3551, -1.4941]]])

t1=t.permute(1,0,2)#必须与t中参数数目一致,执行完这句之后含义表示randn(4,2,5) 4个2*5的矩阵
t1
tensor([[[ 0.2464, -1.5848,  0.4432, -0.8214, -1.3044],
         [ 0.4147, -1.2642,  1.1018,  0.4975, -0.3797]],

        [[-0.0355, -0.4341,  0.3624, -1.4011,  0.0111],
         [-1.0450,  1.0998, -0.8400,  0.5221,  1.0553]],

        [[ 1.3601,  0.1008, -1.4646,  0.2118,  0.1643],
         [-0.7401,  1.4456,  0.9995, -0.6732, -0.5768]],

        [[ 1.9176, -0.0868,  0.8551,  0.4760, -1.5810],
         [ 1.0525,  0.5885,  1.3591, -0.3551, -1.4941]]])

t2=t.permute(2,0,1)#表示randn(5,2,4) 5个2*4的矩阵
t2
tensor([[[ 0.2464, -0.0355,  1.3601,  1.9176],
         [ 0.4147, -1.0450, -0.7401,  1.0525]],

        [[-1.5848, -0.4341,  0.1008, -0.0868],
         [-1.2642,  1.0998,  1.4456,  0.5885]],

        [[ 0.4432,  0.3624, -1.4646,  0.8551],
         [ 1.1018, -0.8400,  0.9995,  1.3591]],

        [[-0.8214, -1.4011,  0.2118,  0.4760],
         [ 0.4975,  0.5221, -0.6732, -0.3551]],

        [[-1.3044,  0.0111,  0.1643, -1.5810],
         [-0.3797,  1.0553, -0.5768, -1.4941]]])


你可能感兴趣的:(pytorch,python,深度学习)