在pytorch中转置用的函数就只有这两个
transpose():操作不了多维
permute():可以操作多维
百度搜索的教程中的理论太多太多,不如直接代码测试容易让人理解,话不多说,直接代码尝试:
t = torch.randn(2, 4, 5) # 首先创建2个正态分布的4*5的矩阵
print(t)
tensor([[[ 1.0224, 0.5716, -1.2172, -0.0534, -1.0312],
[ 0.0622, -0.0260, 2.6485, -0.9420, -0.1987],
[-0.6560, 0.0956, -2.2045, -0.6329, 2.3294],
[-0.0351, 1.0526, -0.1086, -1.1315, -0.2870]],
[[ 0.0081, -0.5649, -0.4293, -0.4485, -1.5479],
[-0.0086, -1.2145, 2.0289, 0.5889, -0.2644],
[ 0.1313, 0.2485, -1.1323, -0.8699, 0.2849],
[ 0.3727, -0.0079, 0.3927, 1.4980, 0.5328]]])
# randn(2,4,5)中三个数的索引分别为 0,1,2
t1=t.transpose(1,0) #此时transpose中的参数1,0表示交换t中的索引位置,t1也既是(4,2,5)表示4个2*5的矩阵
# t1=t.transpose(0,1)也是和上述一样的意思,交换0,1的位置,结果都是下图结果
print(t1)
tensor([[[ 1.0224, 0.5716, -1.2172, -0.0534, -1.0312],
[ 0.0081, -0.5649, -0.4293, -0.4485, -1.5479]],
[[ 0.0622, -0.0260, 2.6485, -0.9420, -0.1987],
[-0.0086, -1.2145, 2.0289, 0.5889, -0.2644]],
[[-0.6560, 0.0956, -2.2045, -0.6329, 2.3294],
[ 0.1313, 0.2485, -1.1323, -0.8699, 0.2849]],
[[-0.0351, 1.0526, -0.1086, -1.1315, -0.2870],
[ 0.3727, -0.0079, 0.3927, 1.4980, 0.5328]]])
同理,permute道理相通,只是可以操作多维而已,但必须传入所有维度数。
t = torch.randn(2, 4, 5) # 首先创建2个正态分布的4*5的矩阵
t
tensor([[[ 0.2464, -1.5848, 0.4432, -0.8214, -1.3044],
[-0.0355, -0.4341, 0.3624, -1.4011, 0.0111],
[ 1.3601, 0.1008, -1.4646, 0.2118, 0.1643],
[ 1.9176, -0.0868, 0.8551, 0.4760, -1.5810]],
[[ 0.4147, -1.2642, 1.1018, 0.4975, -0.3797],
[-1.0450, 1.0998, -0.8400, 0.5221, 1.0553],
[-0.7401, 1.4456, 0.9995, -0.6732, -0.5768],
[ 1.0525, 0.5885, 1.3591, -0.3551, -1.4941]]])
t1=t.permute(1,0,2)#必须与t中参数数目一致,执行完这句之后含义表示randn(4,2,5) 4个2*5的矩阵
t1
tensor([[[ 0.2464, -1.5848, 0.4432, -0.8214, -1.3044],
[ 0.4147, -1.2642, 1.1018, 0.4975, -0.3797]],
[[-0.0355, -0.4341, 0.3624, -1.4011, 0.0111],
[-1.0450, 1.0998, -0.8400, 0.5221, 1.0553]],
[[ 1.3601, 0.1008, -1.4646, 0.2118, 0.1643],
[-0.7401, 1.4456, 0.9995, -0.6732, -0.5768]],
[[ 1.9176, -0.0868, 0.8551, 0.4760, -1.5810],
[ 1.0525, 0.5885, 1.3591, -0.3551, -1.4941]]])
t2=t.permute(2,0,1)#表示randn(5,2,4) 5个2*4的矩阵
t2
tensor([[[ 0.2464, -0.0355, 1.3601, 1.9176],
[ 0.4147, -1.0450, -0.7401, 1.0525]],
[[-1.5848, -0.4341, 0.1008, -0.0868],
[-1.2642, 1.0998, 1.4456, 0.5885]],
[[ 0.4432, 0.3624, -1.4646, 0.8551],
[ 1.1018, -0.8400, 0.9995, 1.3591]],
[[-0.8214, -1.4011, 0.2118, 0.4760],
[ 0.4975, 0.5221, -0.6732, -0.3551]],
[[-1.3044, 0.0111, 0.1643, -1.5810],
[-0.3797, 1.0553, -0.5768, -1.4941]]])