Pytorch unsqueeze和squeeze方法


1.维度扩展(unsqueeze)

torch.unsqueeze(tensor, dim)将tensor的指定位置插入1维

import torch
x = torch.arange(10).reshape(2, 5)
x
# tensor([[0, 1, 2, 3, 4],
#         [5, 6, 7, 8, 9]])

torch.unsqueeze(x, 0).shape # torch.Size([1, 2, 5])
torch.unsqueeze(x, 1).shape # torch.Size([2, 1, 5])

torch.unsqueeze(x, 2).shape # torch.Size([2, 5, 1])
torch.unsqueeze(x, -1).shape # torch.Size([2, 5, 1])

2.维度缩减(squeeze)

tensor.squeeze(dim)缩减tensor的指定位置的维度,如果该维度不为1,不做缩减

xx = torch.arange(10*10*2).reshape(10, 10, 2, 1)

# 指定缩减的维度为10,不做处理
xx.squeeze(1).shape # torch.Size([10, 10, 2, 1])

# 指定缩减的第4维的维度位1,缩减
xx.squeeze(-1).shape # torch.Size([10, 10, 2])

你可能感兴趣的:(pytorch)