import torch
# 创建一个3X3的tensor
a = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
a.shape
torch.Size([3, 3])
# dim的范围在[-input.dim() - 1, input.dim() + 1), 左闭右开区间,如果dim是负数,dim = dim + input.dim() + 1.
b0 = a.unsqueeze(0) # 相当于:a.unsqueeze(-3) # dim = -3 + 2 + 1 =0
b1 = a.unsqueeze(1) # 相当于:a.unsqueeze(-2) # dim = -2 + 2 + 1 =1
b2 = a.unsqueeze(2) # 相当于:a.unsqueeze(-1) # dim = -1 + 2 + 1 =2
b0.shape, b1.shape, b2.shape
(torch.Size([1, 3, 3]), torch.Size([3, 1, 3]), torch.Size([3, 3, 1]))
# 我们再给所有b*的dim=1加一个维度
b00 = b0.unsqueeze(1)
b11 = b1.unsqueeze(1)
b22 = b2.unsqueeze(1)
b00.shape, b11.shape, b22.shape,
(torch.Size([1, 1, 3, 3]), torch.Size([3, 1, 1, 3]), torch.Size([3, 1, 3, 1]))
# dim为空时, 去除所有为1的维度
c0 = b00.squeeze()
c1 = b11.squeeze()
c2 = b22.squeeze()
c0.shape, c1.shape, c2.shape
(torch.Size([3, 3]), torch.Size([3, 3]), torch.Size([3, 3]))
# dim为不为空时, 去除指定的dim=1的维度
# 以 b22(torch.Size([3, 1, 3, 1]))) 为例
d0 = b22.squeeze(0) # 不变
d1 = b22.squeeze(1) # 维度减少
d2 = b22.squeeze(2) # 不变
d3 = b22.squeeze(3) # 维度减少
d4 = b22.squeeze(-3)
d0.shape, d1.shape, d2.shape, d3.shape, d4.shape
(torch.Size([3, 1, 3, 1]),
torch.Size([3, 3, 1]),
torch.Size([3, 1, 3, 1]),
torch.Size([3, 1, 3]),
torch.Size([3, 3, 1]))
>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])