关于torch.unsqueeze与torch.squeeze

torch内置的unsqueeze方法是增加1的维度,squeeze和unsqueeze相反,是删除数据中1的维度,当指定数据时,删除指定位置的1的维度,当不指定参数时,删除数据中所有1的维度,具体操作如下:​​​​​​​

a =torch.ones(3,3)
b = torch.unsqueeze(a,1)#第二个参数指定在哪一个维度增加1
#b.shape为([3,1,3])
c = b.squeeze()
c.shape
#结果为([3,3])

除了squeeze和unsqueeze之外,还有另一种方法可以增加或删除维度,即利用[None]的形式,具体操作如下,当然也可以指定在某一维度中添加。使用这种方法如果想删除维度可以采用切片的方式将其取出,也可以使用squeeze方法。

a = torch.ones(3,3)
b = a[None]
#b的shape变为[1,3,3]
c = a[:,None,None]
#c的shape变为[3,3,1,1]

你可能感兴趣的:(python,pytorch)