pytorch中的squeeze和unsqueeze函数的使用

应用场景:当我们进行深度学习使用Image函数导入图片时,默认它的维度为[C, H, W],此时根据模型的需要导入batch这一维度。
部分程序
# 导入要测试的图像(自己找的,不在数据集中),放在源文件目录下
im = Image.open('dog.jpg')
im = transform(im)  # [C, H, W]
im = torch.unsqueeze(im, dim=0)      # 对数据增加一个新维度(batch),因为tensor的参数是[batch, channel, height, width]

unsqueeze 指定哪个下标就增加那个下标长度为1的维度(必须加上下标,)。

squeeze 从中删除长度为1的维度。指定下标就删除那个下标长度为1的维度,若不为1则不删除。

具体详情见如下代码:

unsquenze(增加长度1的维度):

import torch
a = torch.arange(9,dtype=float).reshape((3,3))
print(a)
print(a.shape)        # 查看a的形状
a = a.unsqueeze(0)    
print(a)              # 经过unsqueeze函数后的情况
print(a.shape)        # 维度变化情况

# 结果输出
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]], dtype=torch.float64)
torch.Size([3, 3])
tensor([[[0., 1., 2.],
         [3., 4., 5.],
         [6., 7., 8.]]], dtype=torch.float64)
torch.Size([1, 3, 3])
import torch
a = torch.arange(9,dtype=float).reshape((3,3))
print(a)
print(a.shape)        # 查看a的形状
a = a.unsqueeze(1)    # 等同于a = a.unsqueeze(-2)
print(a)              # 经过unsqueeze函数后的情况
print(a.shape)        # 维度变化情况

# 结果输出
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]], dtype=torch.float64)
torch.Size([3, 3])
tensor([[[0., 1., 2.]],

        [[3., 4., 5.]],

        [[6., 7., 8.]]], dtype=torch.float64)
torch.Size([3, 1, 3])

​
import torch
a = torch.arange(9,dtype=float).reshape((3,3))
print(a)
print(a.shape)        # 查看a的形状
a = a.unsqueeze(2)    # 等同于a = a.unsqueeze(-1)
print(a)              # 经过unsqueeze函数后的情况
print(a.shape)        # 维度变化情况

# 结果输出
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]], dtype=torch.float64)
torch.Size([3, 3])
tensor([[[0.],
         [1.],
         [2.]],

        [[3.],
         [4.],
         [5.]],

        [[6.],
         [7.],
         [8.]]], dtype=torch.float64)
torch.Size([3, 3, 1])

squeeze(减少长度为1的维度):

import torch
a = torch.arange(9,dtype=float).reshape((1,3,1,3))
print(a)
print(a.shape)        # 查看s的形状
a = a.squeeze()       # 等同于a = a.unsqueeze(-1)         
print(a)              # 经过squeeze函数后的情况      
print(a.shape)        # 维度变化情况


# 结果
tensor([[[[0., 1., 2.]],

         [[3., 4., 5.]],

         [[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]], dtype=torch.float64)
torch.Size([3, 3])
   
  
   
     
import torch
a = torch.arange(9,dtype=float).reshape((1,3,1,3))
print(a)
print(a.shape)        # 查看s的形状
a = a.squeeze(0)      # 此时删除a的下标为0的维度长度为1         
print(a)              # 经过squeeze函数后的情况      
print(a.shape)        # 维度变化情况
tensor([[[[0., 1., 2.]],

         [[3., 4., 5.]],

         [[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])
tensor([[[0., 1., 2.]],

        [[3., 4., 5.]],

        [[6., 7., 8.]]], dtype=torch.float64)
torch.Size([3, 1, 3])
import torch
a = torch.arange(9,dtype=float).reshape((1,3,1,3))
print(a)
print(a.shape)        # 查看s的形状
a = a.squeeze(2)      # 此时删除a的下标为2的维度长度为1(等价于a = a.squeeze(-2))       
print(a)              # 经过squeeze函数后的情况      
print(a.shape)        # 维度变化情况

# 结果
tensor([[[[0., 1., 2.]],

         [[3., 4., 5.]],

         [[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])
tensor([[[0., 1., 2.],
         [3., 4., 5.],
         [6., 7., 8.]]], dtype=torch.float64)
torch.Size([1, 3, 3])

import torch
a = torch.arange(9,dtype=float).reshape((1,3,1,3))
print(a)
print(a.shape)        # 查看s的形状
a = a.squeeze(1)      # 此时小标为1或3时,对应的维度长度均为3,所以squeeze没有起作用        
print(a)              # 经过squeeze函数后的情况      
print(a.shape)        # 维度变化情况

# 结果
tensor([[[[0., 1., 2.]],

         [[3., 4., 5.]],

         [[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])
tensor([[[[0., 1., 2.]],

         [[3., 4., 5.]],

         [[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])

你可能感兴趣的:(python,pytorch,python,深度学习)