【深度学习】torch.squeeze()移除维度函数 | torch.unsqueeze()增加某一维度函数 | pytorch

文章目录

  • 前言
  • 一、torch.squeeze()函数
  • 二、torch.unsqueeze()函数


前言

这两个函数在pytorch框架下的深度学习经常用到,这次把它们记录一下。

一、torch.squeeze()函数

torch.squeeze()用来“挤”掉某一个维度为1的维度,或者所有维度为1的维度。(只挤掉维度为1的维度)
例子如下:

import torch
A=torch.rand(1,3,224,224)
B=torch.unsqueeze(A,dim=0)
print(B.shape)

结果:
【深度学习】torch.squeeze()移除维度函数 | torch.unsqueeze()增加某一维度函数 | pytorch_第1张图片
一般来说,这个函数多用于最后网络输出图片的可视化。
如果对维度不为1的维度进行去除:

import torch
A=torch.rand(1,3,224,224)
B=torch.squeeze(A,dim=1)
print(B.shape)
A=torch.rand(1,3,224,224)
B=torch.squeeze(A,dim=2)
print(B.shape)
A=torch.rand(1,3,224,224)
B=torch.squeeze(A,dim=3)
print(B.shape)

【深度学习】torch.squeeze()移除维度函数 | torch.unsqueeze()增加某一维度函数 | pytorch_第2张图片
不会发生变化

二、torch.unsqueeze()函数

torch.unsqueeze()函数用来插入新的维度扩充张量。例子如下:
在第0维度增加一个维度大小为1的维度(也就是在最前面加一个1)

import torch
A=torch.rand(3,224,224)
B=torch.unsqueeze(A,dim=0)
print(B.shape)

结果为:(这个一般用的最多,比如输入的VGG的照片是1,3,224,224.一般的三通道照片是3,224,224,这时就需要用unsqueeze函数)
【深度学习】torch.squeeze()移除维度函数 | torch.unsqueeze()增加某一维度函数 | pytorch_第3张图片
在第1,2,3维度增加一个维度大小为1的维度,只需要把dim改改就行

import torch
A=torch.rand(3,224,224)
B=torch.unsqueeze(A,dim=1)
print(B.shape)
import torch
A=torch.rand(3,224,224)
B=torch.unsqueeze(A,dim=2)
print(B.shape)
import torch
A=torch.rand(3,224,224)
B=torch.unsqueeze(A,dim=3)
print(B.shape)

【深度学习】torch.squeeze()移除维度函数 | torch.unsqueeze()增加某一维度函数 | pytorch_第4张图片

你可能感兴趣的:(pytorch,深度学习,pytorch,python)