torch.squeeze和torch.unsqueeze

torch.squeeze
作用:去除size为1的维度。当维度大于等于2时,squeeze()无作用。

>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])  # 可以看出size=1的维度有第二维(对应下标1)和第四维(对应下标3)
	
>>> y = torch.squeeze(x)  # 去除size为1的维度,即第二维和第四维
>>> y.size()
torch.Size([2, 2, 2])

>>> y = torch.squeeze(x, 0) # 第一维size为2,squeeze无作用
>>> y.size()
torch.Size([2, 1, 2, 1, 2])

>>> y = torch.squeeze(x, 1) # 去除第二维
>>> y.size()
torch.Size([2, 2, 1, 2])

>>> y = torch.squeeze(x, 2) # 第三维size为2,squeeze无作用
>>> y.size()
torch.Size([2, 1, 2, 1, 2])

>>> y = x.squeeze(-1)  # 如果x的最后一维size为1,则去掉
>>> y.size()
torch.Size([2, 1, 2, 1, 2])

torch.unsqueeze
作用:与squeeze()作用相反,用于添加维度。unsqueeze必须指明维度

>>> x = torch.zeros(2, 1, 3, 4, 2)
>>> x.size()
torch.Size([2, 1, 3, 4, 2])

>>>y = torch.unsqueeze(x,0)  # 在x的第0维前添加一个维度
>>>y.size()
torch.Size([1, 2, 1, 3, 4, 2])

>>>y = torch.unsqueeze(x,1)  # 在x的第1维前添加一个维度
>>>y.size()
torch.Size([2, 1, 1, 3, 4, 2])

你可能感兴趣的:(pytorch函数,pytorch,深度学习,机器学习)