首先从数据的排列来看,从一个简单的例子来看:
a = np.linspace(1, 24, 24).reshape(2, 3, 4)
print(a)
# 输出结果:
[[[ 1. 2. 3. 4.]
[ 5. 6. 7. 8.]
[ 9. 10. 11. 12.]]
[[13. 14. 15. 16.]
[17. 18. 19. 20.]
[21. 22. 23. 24.]]]
首先 1 - 24 的24个数字,排列成 shape 为 (2, 3, 4) 的数组,可以看到这种数据排列的方式显然是按照从最低维(这里是一个三维数组,0,1,2)2开始排列数据,这样至少确定了在进行reshape的时候数据是如何填充的,我的猜想是reshape可以看作是两步:
举个例子来看非常明显:
a = np.linspace(1, 24, 24).reshape(2, 3, 4)
print(a[1][0][0])
print(a[0][1][0])
print(a[0][0][1])
# 可以看到输出分别为:
13.0
5.0
2.0
也就是最低维是行(向下增加),第二维是列,第三维是深。
Numpy 中的数组 array 也就是对应 Pytorch 的 tensor ,那么上面的方式应该在 pytorch 中也是适用的,测试:
b = torch.from_numpy(a)
print(b)
print(b[1][0][0])
print(b[0][1][0])
print(b[0][0][1])
# 输出为:
tensor([[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]],
[[13., 14., 15., 16.],
[17., 18., 19., 20.],
[21., 22., 23., 24.]]], dtype=torch.float64)
tensor(13., dtype=torch.float64)
tensor(5., dtype=torch.float64)
tensor(2., dtype=torch.float64)
测试一下 Pytorch 的 reshape 的函数,Pytorch 中经常使用 view()函数来进行改变数组的 size:
c = b.view(4, 3, 2)
print(c)
# 输出结果:
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.]],
[[ 7., 8.],
[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.],
[17., 18.]],
[[19., 20.],
[21., 22.],
[23., 24.]]], dtype=torch.float64)
确实与 Numpy 的 reshape 方式一致
这是一对操作相反的函数,分别用于 降维/升维 操作,例如:
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(a.size())
print(a)
b = a.unsqueeze(0)
print(b.size())
print(b)
# 输出结果为:
torch.Size([2, 3])
tensor([[1, 2, 3],
[4, 5, 6]])
torch.Size([1, 2, 3])
tensor([[[1, 2, 3],
[4, 5, 6]]])
两个函数都需要传入 dim
参数,指定在哪一个维度上进行 压缩 或者 升维,上面的例子中可以看到在第0维上增加了一个维度。当然这对数据本身并没有什么改变,但是 Pytorch 中对输入的数据格式都做了要求,往往要求是一个四元组,所以很多时候都要对原始数据进行一些改变,所以就用得到这个函数了。
与之对应:
c = b.squeeze(0)
print(c.size())
print(c)
# 输出:
torch.Size([2, 3])
tensor([[1, 2, 3],
[4, 5, 6]])
不过因为是压缩,所以只有在维度为1的时候才生效。
从函数名就可以看出这个函数用于重新排列,这与 reshape 之类的函数区别就在于这是是直接进行重新排列,例如:
对于 [[1, 2],
[3, 4],
[5, 6]]
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = a.permute(1, 0)
print(b)
# 输出:
tensor([[1, 3, 5],
[2, 4, 6]])
可以看到这是相当于进行了一个转置,也就是维度上的重新排列,如果得到 (2, 3) 的数组,使用view()函数则是:
c = a.view(2, 3)
print(c)
# 输出:
tensor([[1, 2, 3],
[4, 5, 6]])
可以很明显地看到区别。
对于二维的数据,这个函数就是简单地转置,但是对于三维数据,有点抽象,举个例子来说:
a = torch.arange(1, 25)
a = a.view(2, 3, 4)
print(a)
b = a.permute(2, 1, 0)
print(b)
# 输出结果:
tensor([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]]])
tensor([[[ 1, 13],
[ 5, 17],
[ 9, 21]],
[[ 2, 14],
[ 6, 18],
[10, 22]],
[[ 3, 15],
[ 7, 19],
[11, 23]],
[[ 4, 16],
[ 8, 20],
[12, 24]]])
这里相当于将所有的维度进行了相反的排列,看数据的排列就会发现这就像是矩阵的立体结构进行了旋转,但是内部的数据排列又要改变位置,总之数据要符合原来的排列顺序,这个也称它为高维矩阵的转置,也就是熟知的二维数据转置的推广。(这已经超越我语言的极限了…╮(╯_╰)╭),基本就是这个意思。
更详细的可以参考:Pytorch之permute函数
transpose()函数是经常使用的转置函数,用于。。。转置 ( ̄_ ̄|||)
但是与上面提到的 permute 函数不同的是, transpose 函数只能在两个维度上进行转置,所以如果某次操作涉及三个以及更多维度的转置的话,使用 permute 函数更方便,当然,连续使用多次 transpose 函数也可以达到同样的目的,例如:
a = torch.randn(2, 3, 4)
# transpose() 函数可以进行两个维度的转置 所以连续使用多次transpose()函数等同于对应的permute()函数
b = a.permute(1, 2, 0)
c = a.transpose(0, 1)
c = c.transpose(1, 2)
print(b.equal(c))
print(b.size())
print(c.size())
输出:
True
torch.Size([3, 4, 2])
torch.Size([3, 4, 2])
可以看到,两种函数实现了一样的结果。