torch.squeeze
作用:去除size为1的维度。当维度大于等于2时,squeeze()无作用。
>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2]) # 可以看出size=1的维度有第二维(对应下标1)和第四维(对应下标3)
>>> y = torch.squeeze(x) # 去除size为1的维度,即第二维和第四维
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0) # 第一维size为2,squeeze无作用
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1) # 去除第二维
>>> y.size()
torch.Size([2, 2, 1, 2])
>>> y = torch.squeeze(x, 2) # 第三维size为2,squeeze无作用
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = x.squeeze(-1) # 如果x的最后一维size为1,则去掉
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
torch.unsqueeze
作用:与squeeze()作用相反,用于添加维度。unsqueeze必须指明维度
>>> x = torch.zeros(2, 1, 3, 4, 2)
>>> x.size()
torch.Size([2, 1, 3, 4, 2])
>>>y = torch.unsqueeze(x,0) # 在x的第0维前添加一个维度
>>>y.size()
torch.Size([1, 2, 1, 3, 4, 2])
>>>y = torch.unsqueeze(x,1) # 在x的第1维前添加一个维度
>>>y.size()
torch.Size([2, 1, 1, 3, 4, 2])