应用场景:当我们进行深度学习使用Image函数导入图片时,默认它的维度为[C, H, W],此时根据模型的需要导入batch这一维度。 部分程序
# 导入要测试的图像(自己找的,不在数据集中),放在源文件目录下
im = Image.open('dog.jpg')
im = transform(im) # [C, H, W]
im = torch.unsqueeze(im, dim=0) # 对数据增加一个新维度(batch),因为tensor的参数是[batch, channel, height, width]
unsqueeze 指定哪个下标就增加那个下标长度为1的维度(必须加上下标,)。
squeeze 从中删除长度为1的维度。指定下标就删除那个下标长度为1的维度,若不为1则不删除。
具体详情见如下代码:
unsquenze(增加长度1的维度):
import torch
a = torch.arange(9,dtype=float).reshape((3,3))
print(a)
print(a.shape) # 查看a的形状
a = a.unsqueeze(0)
print(a) # 经过unsqueeze函数后的情况
print(a.shape) # 维度变化情况
# 结果输出
tensor([[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]], dtype=torch.float64)
torch.Size([3, 3])
tensor([[[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]]], dtype=torch.float64)
torch.Size([1, 3, 3])
import torch
a = torch.arange(9,dtype=float).reshape((3,3))
print(a)
print(a.shape) # 查看a的形状
a = a.unsqueeze(1) # 等同于a = a.unsqueeze(-2)
print(a) # 经过unsqueeze函数后的情况
print(a.shape) # 维度变化情况
# 结果输出
tensor([[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]], dtype=torch.float64)
torch.Size([3, 3])
tensor([[[0., 1., 2.]],
[[3., 4., 5.]],
[[6., 7., 8.]]], dtype=torch.float64)
torch.Size([3, 1, 3])
import torch
a = torch.arange(9,dtype=float).reshape((3,3))
print(a)
print(a.shape) # 查看a的形状
a = a.unsqueeze(2) # 等同于a = a.unsqueeze(-1)
print(a) # 经过unsqueeze函数后的情况
print(a.shape) # 维度变化情况
# 结果输出
tensor([[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]], dtype=torch.float64)
torch.Size([3, 3])
tensor([[[0.],
[1.],
[2.]],
[[3.],
[4.],
[5.]],
[[6.],
[7.],
[8.]]], dtype=torch.float64)
torch.Size([3, 3, 1])
squeeze(减少长度为1的维度):
import torch
a = torch.arange(9,dtype=float).reshape((1,3,1,3))
print(a)
print(a.shape) # 查看s的形状
a = a.squeeze() # 等同于a = a.unsqueeze(-1)
print(a) # 经过squeeze函数后的情况
print(a.shape) # 维度变化情况
# 结果
tensor([[[[0., 1., 2.]],
[[3., 4., 5.]],
[[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])
tensor([[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]], dtype=torch.float64)
torch.Size([3, 3])
import torch
a = torch.arange(9,dtype=float).reshape((1,3,1,3))
print(a)
print(a.shape) # 查看s的形状
a = a.squeeze(0) # 此时删除a的下标为0的维度长度为1
print(a) # 经过squeeze函数后的情况
print(a.shape) # 维度变化情况
tensor([[[[0., 1., 2.]],
[[3., 4., 5.]],
[[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])
tensor([[[0., 1., 2.]],
[[3., 4., 5.]],
[[6., 7., 8.]]], dtype=torch.float64)
torch.Size([3, 1, 3])
import torch
a = torch.arange(9,dtype=float).reshape((1,3,1,3))
print(a)
print(a.shape) # 查看s的形状
a = a.squeeze(2) # 此时删除a的下标为2的维度长度为1(等价于a = a.squeeze(-2))
print(a) # 经过squeeze函数后的情况
print(a.shape) # 维度变化情况
# 结果
tensor([[[[0., 1., 2.]],
[[3., 4., 5.]],
[[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])
tensor([[[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]]], dtype=torch.float64)
torch.Size([1, 3, 3])
import torch
a = torch.arange(9,dtype=float).reshape((1,3,1,3))
print(a)
print(a.shape) # 查看s的形状
a = a.squeeze(1) # 此时小标为1或3时,对应的维度长度均为3,所以squeeze没有起作用
print(a) # 经过squeeze函数后的情况
print(a.shape) # 维度变化情况
# 结果
tensor([[[[0., 1., 2.]],
[[3., 4., 5.]],
[[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])
tensor([[[[0., 1., 2.]],
[[3., 4., 5.]],
[[6., 7., 8.]]]], dtype=torch.float64)
torch.Size([1, 3, 1, 3])