今天在想把一个三维的[3, x, y]
的tensor
转为[x, y, 3]
遇到一些问题,最后的解决方法是把tensor
转为numpy
,然后使用numpy.transpose(mytensor, [1, 2, 0])
解决。因此分析一下torch和numpy中的transpose函数。
用法如下:
numpy.transpose(a, axes=None)
a = np.random.randint(0, 10, (3, 2))
print(a, a.shape)
a = np.transpose(a)
print(a, a.shape)
a = np.transpose(a, (1, 0))
print(a, a.shape)
结果为:
[[5 7]
[2 3]
[9 1]]
(3, 2)
[[5 2 9]
[7 3 1]]
(2, 3)
[[5 7]
[2 3]
[9 1]]
(3, 2)
可以看到,每次都是将该二维矩阵转置。
a = np.random.randint(0, 10, (3, 2, 4))
print(a, a.shape)
a = np.transpose(a, (1, 2, 0))
print(a, a.shape)
结果:
[[[4 3 6 8]
[7 0 1 1]]
[[0 2 6 4]
[4 0 6 2]]
[[3 3 4 6]
[5 6 6 2]]]
(3, 2, 4)
[[[4 0 3]
[3 2 3]
[6 6 4]
[8 4 6]]
[[7 4 5]
[0 0 6]
[1 6 6]
[1 2 2]]]
(2, 4, 3)
可以看到,本来矩阵的形状是pre = [3, 2, 4]
,变换时传入的参数是(1, 2, 0)
,之后矩阵就变成了[2, 4, 3]
也就是[pre[1], pre[2], pre[0]]
。
但是具体的变换过程还没搞懂。
axes
,默认是:range(a.ndim)[::-1]
,也就是原来矩阵shape
的逆序。a = np.random.randint(0, 10, (3, 2, 4))
print(a, a.shape)
b = a.transpose([1, 2, 0])
print(b, b.shape)
结果:
[[[5 0 8 0]
[5 4 5 8]]
[[6 8 8 8]
[8 0 4 5]]
[[2 6 8 3]
[0 8 3 4]]]
(3, 2, 4)
[[[5 6 2]
[0 8 6]
[8 8 8]
[0 8 3]]
[[5 8 0]
[4 0 8]
[5 4 3]
[8 5 4]]]
(2, 4, 3)
效果是一样的。
a = np.random.randint(0, 10, (2, 4))
print(a)
b = a.transpose()
print(b)
结果:
[[6 5 4 8]
[3 2 1 2]]
[[6 3]
[5 2]
[4 1]
[8 2]]
b[0][0] = 15
print(a)
print(b)
结果:
[[15 5 4 8]
[ 3 2 1 2]]
[[15 3]
[ 5 2]
[ 4 1]
[ 8 2]]
一个改变,另一个也改变。
用法如下:
torch.transpose(input, dim0, dim1)
返回一个tensor
,是input
的转置。并且同样是共享一个实际tensor,一个改变另一个也改变。
import torch
a = torch.randint(0, 10, (2, 4))
print(a)
b = torch.transpose(a, 1, 0)
print(b)
c = torch.transpose(a, 0, 1)
print(c)
结果:
tensor([[2, 7, 0, 9],
[8, 2, 8, 7]])
tensor([[2, 8],
[7, 2],
[0, 8],
[9, 7]])
tensor([[2, 8],
[7, 2],
[0, 8],
[9, 7]])
可以看到,都是在进行矩阵的转置。
在本函数中,dim0
和dim1
会互换(转置),所以transpose(a, 1, 0)
和transpose(a, 0, 1)
效果一致,都是dim[0]和dim[1]互换。这与numpy的函数不同。
a = torch.randint(0, 10, (2, 3, 4))
print(a)
b = torch.transpose(a, 1, 2)
print(b)
结果:
tensor([[[2, 0, 6, 7],
[8, 8, 0, 2],
[6, 7, 6, 6]],
[[9, 1, 6, 4],
[8, 3, 2, 8],
[0, 0, 4, 9]]])
tensor([[[2, 8, 6],
[0, 8, 7],
[6, 0, 6],
[7, 2, 6]],
[[9, 8, 0],
[1, 3, 0],
[6, 2, 4],
[4, 8, 9]]])