pytorch中张量的维度变换,torch.squeeze()、torch.unsqueeze()函数

张量维度变换

通过.reshape方法,能够灵活调整张量的形状。而在实际操作张量进行计算时,往往需要另外进行降维和升维的操作,当我们需要除去不必要的维度时,可以使用squeeze函数,而需要手动升维时,则可采用unsqueeze函数。

1.torch.squeeze()函数:删除不必要的维度,提出了shape返回结果中的1

(1)torch.squeeze()函数的使用

t = torch.zeros(1, 1, 3, 1)  #创建4维张量,由1个三维张量组成,每个三维张量由1个二维张量组成,每个二维张量由3个一维张量组成,每个一维张量有1个元素零(浮点型)组成。
#结果为:tensor([[[[0.],
                   [0.],
                   [0.]]]])
t.shape
#结果为:torch.Size([1, 1, 3, 1])
t张量解释:一个包含一个三维的四维张量,三维张量只包含一个三行一列的二维张量。
torch.squeeze(t) #降维
#结果为:tensor([0., 0., 0.])
torch.squeeze(t).shape
#结果为:torch.Size([3])  #一个数表示一维张量,含有3个元素。转化后生成了一个一维张量  
          
t1 = torch.zeros(1, 1, 3, 2, 1, 2)
t1.shape
#结果为:torch.Size([1, 1, 3, 2, 1, 2])
torch.squeeze(t1)    #降维                 
#结果为:tensor([[[0., 0.],
                  [0., 0.]],

                 [[0., 0.],
                  [0., 0.]],

                 [[0., 0.],
                  [0., 0.]]])
torch.squeeze(t1).shape
#结果为:torch.Size([3, 2, 2])                 

简单理解:squeeze就相当于提出了shape返回结果中的1。

2. torch.unsqueeze()函数:手动升维,维度增加是从前往后数,在shape返回序列的第几个位置插入1

(1) torch.unsqueeze()函数的使用

t = torch.zeros(1, 2, 1, 2)#创建4维张量,由1个三维张量组成,每个三维张量由2个二维张量组成,每个二维张量由1个一维张量组成,每个一维张量有2个元素零(浮点型)组成。
t.shape 
#结果为:torch.Size([1, 2, 1, 2])
torch.unsqueeze(t, dim = 0)   # 在第1个维度索引上升高1个维度
#结果为:tensor([[[[[0., 0.]],

                   [[0., 0.]]]]])

torch.unsqueeze(t, dim = 0).shape
#结果为:torch.Size([1, 1, 2, 1, 2])

torch.unsqueeze(t, dim = 2).shape   # 在第3个维度索引上升高1个维度
#结果为:torch.Size([1, 2, 1, 1, 2]) 
注意观察上面的结果
          
torch.unsqueeze(t, dim = 4).shape   # 在第5个维度索引上升高1个维度
#结果为:torch.Size([1, 2, 1, 2, 1]) 
维度增加是从前往后数,在第几个位置插入1              

注意理解:维度和shape返回结果一一对应的关系,shape返回的序列有几个元素,张量就有多少维度。维度增加是从前往后数,在shape返回序列的第几个位置插入1。

你可能感兴趣的:(pytorch深度学习,pytorch,python,深度学习)