torch.unsqueeze与torch.squeeze

unsqueeze

扩充数据维度,在从0开始的指定位置上增加一维(维度为1)

x = torch.rand(2,3)
y = torch.unsqueeze(x, 1)
y = torch.unsqueeze(y, 0)

print(x.shape)
print(y.shape)
>>torch.Size([2, 3])
>>torch.Size([1, 2, 1, 3])

也可以倒着数,比如torch.unsqueeze(x,-1),就是在最后添加一维

squeeze

维度压缩,在从0开始的指定位置上,去掉维数为1的的维度

  • 若不指定参数,删除所有为 1 的维度
  • 若指定参数 N
    • 如果第 N 个位置的维度为 1 ,则删除该维度
    • 否则,不受影响
x = torch.rand(2,3)

#增加两个维度
y = torch.unsqueeze(x, 1)
y = torch.unsqueeze(y, 0)

#若第二个位置的维度为 1,则删除。否则,不受影响
z = torch.squeeze(y, 2)

#删除所有为 1 的维度
m = torch.squeeze(y)
print(x.shape)
>>torch.Size([2, 3])

print(y.shape)
>>torch.Size([1, 2, 1, 3])

print(z.shape)
>>torch.Size([1, 2, 3])

print(m.shape)
>>torch.Size([2, 3])

你可能感兴趣的:(PyTorch,pytorch,深度学习,机器学习)