squeeze和unsqueeze

1.unqueeze-指定dim添加一个维度

import torch
# 创建一个3X3的tensor
a = torch.tensor([[1, 2, 3], 
                  [4, 5, 6], 
                  [7, 8, 9]])
a.shape

torch.Size([3, 3])

# dim的范围在[-input.dim() - 1, input.dim() + 1), 左闭右开区间,如果dim是负数,dim = dim + input.dim() + 1.

b0 = a.unsqueeze(0) # 相当于:a.unsqueeze(-3)  # dim = -3 + 2 + 1 =0
b1 = a.unsqueeze(1) # 相当于:a.unsqueeze(-2)  # dim = -2 + 2 + 1 =1
b2 = a.unsqueeze(2) # 相当于:a.unsqueeze(-1) # dim = -1 + 2 + 1 =2

b0.shape, b1.shape, b2.shape

(torch.Size([1, 3, 3]), torch.Size([3, 1, 3]), torch.Size([3, 3, 1]))

2.squeeze-指定dim减少一个维度

# 我们再给所有b*的dim=1加一个维度
b00 = b0.unsqueeze(1) 
b11 = b1.unsqueeze(1) 
b22 = b2.unsqueeze(1) 

b00.shape, b11.shape, b22.shape, 

(torch.Size([1, 1, 3, 3]), torch.Size([3, 1, 1, 3]), torch.Size([3, 1, 3, 1]))

# dim为空时, 去除所有为1的维度
c0 = b00.squeeze()
c1 = b11.squeeze()
c2 = b22.squeeze()

c0.shape, c1.shape, c2.shape

(torch.Size([3, 3]), torch.Size([3, 3]), torch.Size([3, 3]))

# dim为不为空时, 去除指定的dim=1的维度
# 以 b22(torch.Size([3, 1, 3, 1]))) 为例
d0 = b22.squeeze(0)  # 不变
d1 = b22.squeeze(1)  #  维度减少
d2 = b22.squeeze(2)  # 不变
d3 = b22.squeeze(3)  #  维度减少
d4 = b22.squeeze(-3)
d0.shape, d1.shape, d2.shape, d3.shape, d4.shape

(torch.Size([3, 1, 3, 1]),
torch.Size([3, 3, 1]),
torch.Size([3, 1, 3, 1]),
torch.Size([3, 1, 3]),
torch.Size([3, 3, 1]))

>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])

你可能感兴趣的:(LIST,python,深度学习,开发语言)