函数功能:与squeeze()函数功能相反,用于添加维度;
>>> a = torch.Tensor(3)
>>> a
tensor([1.7718e+28, 1.0509e-38, 0.0000e+00])
>>> a.unsqueeze(0) # 扩展第一个维度
tensor([[1.7718e+28, 1.0509e-38, 0.0000e+00]])
>>> a.unsqueeze(1) #扩展第二个维度
tensor([[1.7718e+28],
[1.0509e-38],
[0.0000e+00]])
squeeze本身有挤压的意思;
函数功能:去除size为1的维度,包括行和列。当维度大于等于2时,squeeze()无作用;
其中squeeze(0)代表若第一维度值为1则去除第一维度;
squeeze(1)代表若第二维度值为1则去除第二维度;
>>> a = torch.Tensor(3,2)
>>> a
tensor([[-6.5850e+34, 4.5759e-41],
[-6.5850e+34, 4.5759e-41],
[ 0.0000e+00, 0.0000e+00]])
>>> a.unsqueeze(0)
tensor([[[-6.5850e+34, 4.5759e-41],
[-6.5850e+34, 4.5759e-41],
[ 0.0000e+00, 0.0000e+00]]])
>>> a.squeeze(0) # 第一维度会被缩减
tensor([[-6.5850e+34, 4.5759e-41],
[-6.5850e+34, 4.5759e-41],
[ 0.0000e+00, 0.0000e+00]])
>>> a.squeeze(1)# 第二维度不会被缩减
tensor([[-6.5850e+34, 4.5759e-41],
[-6.5850e+34, 4.5759e-41],
[ 0.0000e+00, 0.0000e+00]])
>>> a.unsqueeze(-1) # 表示最后一个维度
tensor([[[-6.5850e+34],
[ 4.5759e-41]],
[[-6.5850e+34],
[ 4.5759e-41]],
[[ 0.0000e+00],
[ 0.0000e+00]]])
>>> a.squeeze(-1)
tensor([[-6.5850e+34, 4.5759e-41],
[-6.5850e+34, 4.5759e-41],
[ 0.0000e+00, 0.0000e+00]])
参考: 【学习笔记】pytorch中squeeze()和unsqueeze()函数介绍_Jaborie203的博客-CSDN博客_pytorch unsqueeze