PyTorch cat, stack, transpose, permute, view和reshape详解

简述

使用Pytorch过程中,我们经常对torch类型矩阵进行变换,常用的命令较多,我也将常用的命令进行学习整理,欢迎各位小伙伴一起学习,有错误欢迎批评指正!

1. cat

对数据沿着某一维度进行拼接。cat后数据的总维数不变.

比如下面代码对两个2维tensor(分别为2x3,1x3)进行拼接,拼接完后变为3x3还是2维的tensor。

import torch
x = torch.randn(2, 3)
#tensor([[ 0.6614,  0.2669,  0.0617],
#        [ 0.6213, -0.4519, -0.1661]])
y = torch.randn(1, 3)
# tensor([[-1.5228,  0.3817, -1.0276]])
torch.cat((x, y), dim=0) # 0维度拼接
# tensor([[ 0.6614,  0.2669,  0.0617],
#        [ 0.6213, -0.4519, -0.1661],
#        [-1.5228,  0.3817, -1.0276]])

2. stack

增加新的维度进行堆叠, 将若干个张量在dim维度上连接,生成一个扩维的张量,比如说原来你有若干个2维张量,连接可以得到一个3维的张量
stack(tensors,dim=0,out=None)

而stack则会增加新的维度。
如对两个1x2维的tensor在第0个维度上stack,则会变为2x1x2的tensor;在第1个维度上stack,则会变为1x2x2的tensor。

a=torch.rand((1,2))
# tensor([[0.3138, 0.1980]])
b=torch.rand((1,2))
# tensor([[0.4162, 0.2843]])
c=torch.stack((a,b),0)
# tensor([[[0.3138, 0.1980]],
#        [[0.4162, 0.2843]]])
c.shape
# torch.Size([2, 1, 2])
d = torch.stack((a,b),1)
# tensor([[[0.3138, 0.1980],
#         [0.4162, 0.2843]]])
d.shape
# torch.Size([1, 2, 2])

3. transpose

交换指定的两个维度的内容

x = torch.randn(2, 3)
x
# tensor([[ 0.6614,  0.2669,  0.0617],
#        [ 0.6213, -0.4519, -0.1661]])
x.transpose(0, 1)
# tensor([[ 0.6614,  0.6213],
#        [ 0.2669, -0.4519],
#        [ 0.0617, -0.1661]])

4. permute

一次性交换多个维度

x = torch.randn(2,3,4)
x
# tensor([[[ 0.4391,  1.1712,  1.7674, -0.0954],
#         [ 0.1394, -1.5785, -0.3206, -0.2993],
#         [-0.9274,  0.5451,  0.0663, -0.4370]],
#        [[ 0.7626,  0.4415,  1.1651,  2.0154],
#         [ 0.1374,  0.9386, -0.1860, -0.6446],
#         [ 1.5392, -0.8696, -3.3312, -0.7479]]])
x.permute(2, 0, 1)
# tensor([[[ 0.4391,  0.1394, -0.9274],
#         [ 0.7626,  0.1374,  1.5392]],
#        [[ 1.1712, -1.5785,  0.5451],
#         [ 0.4415,  0.9386, -0.8696]],

5. view

transpose和permute是将张量的维度进行变换,而view是将张量拉伸成一维,然后根据传入的维度(也就是想要变换的维度),重构出一个新的张量。

x = torch.randn(2, 3)
x
# tensor([[-0.5631,  0.1103, -2.2590],
#        [ 0.6067, -0.1383,  0.8310]])
x.view(3, -1)
# tensor([[-0.5631,  0.1103],
#        [-2.2590,  0.6067],
#        [-0.1383,  0.8310]])

6. reshape

view方法类似,将输入tensor转换为新的shape格式。
但是reshape方法更强大,可以认为a.reshape = a.view() + a.contiguous().view()

即:在满足tensor连续性条件时a.reshape返回的结果与a.view()相同,否则返回的结果与a.contiguous().view()相同。

简单理解,view操作是在原tensor进行的,当我们进行对原tensor操作后,tensor不在满足连续性条件(其实是我们通过view修改了查找规则,但是数据排序并没有改变),如果我们再一次修改不连续的tensor就会出错,解决方法是contiguous,其实就是重新开辟了一块空间,进一步保证了view后的存储方式,进而可以继续修改。而如果使用reshape的话,会直接赋值一个副本,进行操作,就避免了view的情况。因此,reshape方法的鲁棒性会更强。

如果大家还没有理解的话,给推荐一篇博客view和reshape区别

7. unsqueeze 和 squeeze

增加维度和压缩维度,通过例子就会很好!

import torch
x = torch.randn(2,3)
x.shape
# torch.Size([2, 3])
x.unsqueeze(1).shape
# torch.Size([2, 1, 3])
x.unsqueeze(0).shape
# torch.Size([1, 2, 3])
x.unsqueeze(2).shape
# torch.Size([2, 3, 1])

总结

以上就是借助简单的例子学习,总结了常用的矩阵修改的方法,希望可以帮助到大家

你可能感兴趣的:(PyTorch,Python,pytorch,深度学习)