Pytorch——维度变换

Operation

    • view/reshape
      • lost dim information
      • 数据污染
    • squeeze/unsequzee
      • 增加维度
        • for instance
        • for example
      • 删减维度
    • expand/repeat
      • expand:broadcasting
      • repeat:memory copied
    • transpose/t/permute
      • .t( )
      • transpose交换维度
      • permute

view/reshape

lost dim information

view操作必须有物理意义:对四张图片合在一起,忽略上下左右位移信息、二维信息、通道信息。

>>> import torch
>>> a = torch.rand(4,1,28,28)
>>> a.view(4,28*28)
tensor([[0.9750, 0.1818, 0.7628,  ..., 0.6651, 0.3677, 0.8265],
        [0.0921, 0.1000, 0.8699,  ..., 0.2049, 0.3385, 0.5174],
        [0.6671, 0.3868, 0.1271,  ..., 0.8727, 0.7295, 0.7776],
        [0.4635, 0.9024, 0.6649,  ..., 0.4255, 0.0090, 0.0990]])

打印shape:

>>> import torch
>>> a = torch.rand(4,1,28,28)
>>> a.view(4,28*28).shape
torch.Size([4, 784])

多种理解方式:
①把前三个维度合并,即把所有照片、所有通道、所有行都放到第一维度,变成N。每一个
N都有一个一列一行的数据,一行的数据是28个像素点。只关注一行的数据。

>>> import torch
>>> a = torch.rand(4,1,28,28)
>>> a.view(4*28,28).shape
torch.Size([112, 28])

②把前边两个合并,即把所有照片、所有通道合并在一起,只关注数据

>>> import torch
>>> a = torch.rand(4,1,28,28)
>>> a.view(4*1,28,28).shape
torch.Size([4, 28, 28])

数据污染

b没有按照原来的维度信息存储,只有知道了a额外的维度信息,b才能恢复成a。

>>> import torch
>>> a = torch.rand(4,1,28,28)
>>> b = a.view(4,784)
>>> b.view(4,28,28,1)

squeeze/unsequzee

增加维度

道理很简单,就是通过这个函数增加一个维度,但是不会改变数据,只是增加了一个组别。比如a是四维的那么可以增加额范围就是[-5,5)
也就是[-a.dim()-1, a.dim()+1)

正数是在当前位置之前插入,负数是在当前位置之后插入。

对应的插入顺序如下:Pytorch——维度变换_第1张图片
可见0-4完全可以满足任何位置的维度增加,所以尽量不适用负数了。

下边分别是0、-1、4、-5的情况

>>> import torch
>>> a = torch.rand(4,1,28,28)
>>> a.unsqueeze(0).shape
torch.Size([1, 4, 1, 28, 28])

>>> import torch
>>> a = torch.rand(4,1,28,28)
>>> a.unsqueeze(-1).shape
torch.Size([4, 1, 28, 28, 1])

>>> import torch
>>> a = torch.rand(4,1,28,28)
>>> a.unsqueeze(4).shape
torch.Size([4, 1, 28, 28, 1])

>>> import torch
>>> a = torch.rand(4,1,28,28)
>>> a.unsqueeze(-5).shape
torch.Size([1, 4, 1, 28, 28])

注意一下不能等于5,会报错:

>>> import torch
>>> a = torch.rand(4,1,28,28)
>>> a.unsqueeze(5).shaoe
Traceback (most recent call last):
  File "", line 1, in 
IndexError: Dimension out of range (expected to be in range of [-5, 4], but got 5)

for instance

用一个小例子来对比一下正负数的区别,以及只改变理解方式,不改变数据本身:

首先引入一个a,注意它是0维度、它的shape是2

>>> import torch
>>> a = torch.tensor([1.2,2.3])
>>> a.shape
torch.Size([2])

增加-1:

>>> import torch
>>> a = torch.tensor([1.2,2.3])
>>> a.unsqueeze(-1).shape
torch.Size([2, 1])

>>> import torch
>>> a = torch.tensor([1.2,2.3])
>>> a.unsqueeze(-1)
tensor([[1.2000],
        [2.3000]])

增加0:

>>> import torch
>>> a = torch.tensor([1.2,2.3])
>>> a.unsqueeze(0).shape
torch.Size([1, 2])

>>> import torch
>>> a = torch.tensor([1.2,2.3])
>>> a.unsqueeze(0)
tensor([[1.2000, 2.3000]])

for example

要将a和b叠加在一起,那就必须保证它们的维度相同

>>> import torch
>>> a = torch.rand(32)
>>> b = torch.rand(4,32,14,14)
>>> a = a.unsqueeze(0).unsqueeze(1).unsqueeze(2)
>>> a.shape
torch.Size([1, 1, 1, 32])

这是我第一遍运行后的结果,很明显是错误的。原因是没有考虑每一步插入后维度发生的变化,正确的应该是这样:

>>> import torch
>>> a = torch.rand(32)
>>> b = torch.rand(4,32,14,14)
>>> a = a.unsqueeze(0).unsqueeze(2).unsqueeze(3)
>>> a.shape
torch.Size([1, 32, 1, 1])

图示:
Pytorch——维度变换_第2张图片

删减维度

和增加维度类似,可正可负:

>>> import torch
>>> b = torch.rand(1,32,1,1)
>>> b.squeeze(0).shape
torch.Size([32, 1, 1])

>>> import torch
>>> b = torch.rand(1,32,1,1)
>>> b.squeeze(-1).shape
torch.Size([1, 32, 1])

如果没有参数,则会尽可能删除所有能删除的维度:

>>> import torch
>>> b = torch.rand(1,32,1,1)
>>> b.squeeze().shape
torch.Size([32])

32是无法删除的:

>>> import torch
>>> b = torch.rand(1,32,1,1)
>>> b.squeeze(1).shape
torch.Size([1, 32, 1, 1])

expand/repeat

expand:broadcasting

扩展维度,注意只能从1变为N,不能从M变成N

>>> import torch
>>> b = torch.rand(1,32,1,1)
>>> b.expand(4,32,14,14).shape
torch.Size([4, 32, 14, 14])

如果某一维度不需要扩展,可以用-1保持不变。但这里有一个bug,就是如果写成-4,它也会变成-4,但无意义。

>>> import torch
>>> b = torch.rand(1,32,1,1)
>>> b.expand(-1,32,-1,-4).shape
torch.Size([1, 32, 1, -4])

repeat:memory copied

repeat给的参数表示的是每一个维度需要重复的次数:

如果用expand相同的方法扩展得到的结果是这样的:

>>> import torch
>>> b = torch.rand(1,32,1,1)
>>> b.repeat(4,32,14,14).shape
torch.Size([4, 1024, 14, 14])

如想得到正常的(4,32,14,14),应为:

>>> import torch
>>> b = torch.rand(1,32,1,1)
>>> b.repeat(4,1,14,14).shape
torch.Size([4, 32, 14, 14])

这个repeat函数会占用新的内存空间,不建议使用。

transpose/t/permute

.t( )

只能用于二维的,比如矩阵转置:

>>> import torch
>>> a = torch.rand(3,4)
>>> a.t()
tensor([[0.9967, 0.4838, 0.3454],
        [0.7178, 0.7012, 0.2024],
        [0.8568, 0.5522, 0.8467],
        [0.0917, 0.4319, 0.9853]])

transpose交换维度

只能将维度两两交换。直接使用该函数会报错,因为内存数据会被打乱,变得不连续:

>>> import torch
>>> a = torch.rand(4,3,32,32)
>>> a1 = a.transpose(1,3).view(4,3*32*32).view(4,3,32,32)
Traceback (most recent call last):
  File "", line 1, in 
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

需要使用到contiguous()函数

>>> import torch
>>> a = torch.rand(4,3,32,32)
>>> a1 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,3,32,32)
>>> a2 = a.transpose(1,3).contiguous().view(4,3*32*32).view(4,32,32,3).transpose(1,3)
>>> a1.shape,a2.shape
(torch.Size([4, 3, 32, 32]), torch.Size([4, 3, 32, 32]))

a1和a2是不同的。

permute

更方便的一种维度交换方法:

>>> import torch
>>> a = torch.rand(4,3,28,32)
>>> a.permute(0,2,3,1).shape
torch.Size([4, 28, 32, 3])

你可能感兴趣的:(PyTorch)