【Pytorch入门】Tensor维度变换

Pytorch学习笔记——Tensor维度变换

view()/reshape()

torch.view(arg1,arg2....)

作用:类似于numpy中的resize()的功能,重构tensor的维度,返回一个有相同数据但不同大小的tensor

简单来讲,就是把原先tensor中的数据按照行优先的顺序排成一个一维数组,然后根据参数组合成其他维度的tensor

example:

首先我们构造两个张量a([1,2,3,4,5,6]),b([[1,2,3],[4,5,6]]),对其执行view操作后,输出结果

import torch
a=torch.tensor([1,2,3,4,5,6])
print("变换前a的shape为==>",a.shape)
print("变换后a的shape为==>",a.view(1,6).shape)
print("变换前的a:",a)
print("变换后的a:",a.view(1,6))
print("-"*40)
b=torch.tensor([[1,2,3],[4,5,6]])
print("变换前b的shape为==>",b.shape)
print("变换后b的shape为==>",b.view(1,6).shape)
print("变换前的b:",b)
print("变换后的b:",b.view(1,6))

【Pytorch入门】Tensor维度变换_第1张图片

从结果中,我们可以看出,无论是a([1,2,3,4,5,6])还是b([[1,2,3],[4,5,6]]),由于排成一维时均为6个元素,所以只要view中参数一致,其结果也一致

另外,在view中还要注意view(-1)以及view(arg1,-1)这两种情况

example

以一个shape为(3,2)的tensor为例,我们对其执行view(-1),对比观察结果

import torch
a=torch.rand((3,2))
print(a)
print(a.view(-1))

【Pytorch入门】Tensor维度变换_第2张图片

通过上述例子我们可以得知,通过view(-1)我们可以将原tensor转换成一维的结构

import torch
a=torch.rand((3,2))
print(a)
print(a.view(2,-1))

【Pytorch入门】Tensor维度变换_第3张图片

由上面的例子可以看到,如果是torch.view(arg1,-1),则表示在arg2未知,arg1已知的情况下自动补齐列向量长度,在这个例子中arg1=2,tensor a总共由6个元素,则arg2=6/2=3。

注意:数据的存储,维度顺序非常重要,需要记住数据的顺序的实际意义

unsqueeze()

torch.unsqueeze(input,dim,out=None)

作用:扩展维度,对输入的既定位置增加维数为1的维度,返回一个tensor

example:

首先,我们先生成一个shape为(2,3)的向量

import torch
t=torch.rand((2,3))
print(t)
print(t.shape)

【Pytorch入门】Tensor维度变换_第4张图片

然后插入第 0 维,插入后向量的 shape 为(1,2,3),在倒数第一个维度上增加一个维度,插入后向量的shape为(2,3,1)

t.unsqueeze(0)
t.unsqueeze(0).shape

【Pytorch入门】Tensor维度变换_第5张图片

t.unsqueeze(-1)
t.unsqueeze(-1).shape

【Pytorch入门】Tensor维度变换_第6张图片

squeeze()

torch.squeeze(input, dim=None, out=None)

作用:压缩维度,删除input中大小为1的所有维,如果给定dim,则只在给定的维度上进行压缩操作,返回一个tensor

example:

首先我们先得到一个shape为(2,1,2,1)的tensor

t=torch.zeros(2,1,2,1)
print(t.shape)

在这里插入图片描述

由图中可以看出t的维度为(2,1,2,1)

接下来我们使用squeeze()来压缩维度

y=torch.squeeze(t)
print(y.shape)

在这里插入图片描述

我们也可以指定dim,举个例子,我们将倒数第一维去掉得到新的tensor,输出其shape,此时由于倒数第一维的维数为1因此可以被去掉

z=torch.squeeze(t,-1)
print(z.shape)

在这里插入图片描述

如果维数不为1则没有变化,如我们尝试去掉第0维

y=torch.squeeze(t,0)
print(y.shape)

在这里插入图片描述

由图中可以看出,新得到的tensor的shape并没有发生改变,因为只有维度为1时才会去掉

expand()

作用:将tensor广播到新的形状

注意:只能对维数为1的维度进行扩展,且扩展的Tensor不会分配新的内存,只是原来的基础上创建新的视图并返回

example

import torch
a=torch.rand((1,1,3))
#-1表示保持原来的维数不变
b=a.expand(2,2,-1)
print("构造的tensor a :",a)
print("变换后的tensor b:",b)
print("-"*40)
print("tensor a 的形状:",a.shape)
print("tensor b 的形状:",b.shape)

【Pytorch入门】Tensor维度变换_第7张图片

repeat()

作用:沿着特定的维度重复这个张量,和*expand()*不同的是,这个函数拷贝张量的数据。

example

import torch
a=torch.rand((1,1,3))
b=a.repeat(2,2,1)
c=a.repeat(2,2,3)
print(a.shape)
print(b.shape)
print(c.shape)

【Pytorch入门】Tensor维度变换_第8张图片
通过观察上述代码,我们可以简单了解一下repeat
【Pytorch入门】Tensor维度变换_第9张图片如果无法理解的话,可以借助一下输出的结果去理解
【Pytorch入门】Tensor维度变换_第10张图片

关于为什么要进行维度变换

在深度学习中很多都是矩阵运算,为了满足矩阵乘法我们需要变换维度

你可能感兴趣的:(pytorch)