pytorch 中 squeeze 和unsqueeze函数

1. torch.squeeze() 函数 :

作用:移除指定或所有维数为1的维度,从而得到维度减少的张量

解释一下:

x=torch.zeros(5,1,1,1)

print(x)

'输出'
tensor([[[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]]])

举个极端点的例子,这是一个4维的数组,除了第0个维度之外每个维度的维数均为1

也就是说,每一个0都被3个括号括着,这显然不太合理

下面调用squeeze函数:

y = x.squeeze()
print(y)
print(y.shape)

'输出'
tensor([0., 0., 0., 0., 0.])
torch.Size([5])

瞬间一系列的括号都没有了,是不是看着舒服了许多?

进一步:

y = x.squeeze(1)
print(y)
print(y.shape)

'输出'
tensor([[[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]]])
torch.Size([5, 1, 1])

这里添加了参数1,这样就只压缩了第1个维度(计数从0开始),一个0被2个括号括着

但压缩的前提是,该张量必须有维数为1的维度,比如:

y = x.squeeze(0)
print(y)
print(y.shape)

a = torch.tensor([[1, 1, 1], [2, 2, 2]])
b = a.squeeze()
print(b)
print(b.shape)

'输出'
tensor([[[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]]])
torch.Size([5, 1, 1, 1])

tensor([[1, 1, 1],
        [2, 2, 2]])
torch.Size([2, 3])

y和b相对于x和a均没有发生变化,原因就是:x的第0个维度,维数不是1;a中更是没有维数为1的维度

另外:x.squeeze() 或者 torch.squeeze(x) 都不会让x发生改变

y = x.squeeze()
print(x.shape)
print(y.shape)


'输出'
torch.Size([5, 1, 1, 1])
torch.Size([5])

2. torch.unsqueeze() 函数 :

作用:在张量的制定维度插入新的维度得到维度提升的张量

举个例子:

 x= torch.zeros(5)
print(x)
print(x.shape)

'输出'
tensor([0, 0, 0, 0, 0])
torch.Size([5])

一维张量,总共5个0,接下来依次操作:

y = x.unsqueeze(dim=0)
print(y)
print(y.shape)

y = x.unsqueeze(dim=1)
print(y)
print(y.shape)

z = y.unsqueeze(dim=2)
print(z)
print(z.shape)

'输出'
tensor([[0., 0., 0., 0., 0.]])
torch.Size([1, 5])

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.]])
torch.Size([5, 1])

tensor([[[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]]])
torch.Size([5, 1, 1])

把第0维进行扩张,就是在最外面加了一个括号

把第1维进行扩张,就是把里面的每个元素元素(也可以理解成是,扩充后的第1维,也就是0.)都加一个括号

继续套,选择dim=2,还是把最内层的了(也可以理解成是,扩充后的第2维,也就是0.),都加一个括号

再举个例子:

a = torch.tensor([[1, 1, 1], [2, 2, 2]])
print(a)
print(a.shape)

b = a.unsqueeze(dim=0)
print(b)
print(b.shape)

b = a.unsqueeze(dim=1)
print(b)
print(b.shape)

b = a.unsqueeze(dim=2)
print(b)
print(b.shape)

'输出'
tensor([[1, 1, 1],
        [2, 2, 2]])
torch.Size([2, 3])

tensor([[[1, 1, 1],
         [2, 2, 2]]])
torch.Size([1, 2, 3])

tensor([[[1, 1, 1]],
        [[2, 2, 2]]])
torch.Size([2, 1, 3])

tensor([[[1],
         [1],
         [1]],
        [[2],
         [2],
         [2]]])
torch.Size([2, 3, 1])

怎么套的括号,是不是一目了然~

同样:x.unsqueeze() 或者 torch.unsqueeze(x) 都不会让x发生改变

x = torch.tensor([[1, 1, 1], [2, 2, 2]])
print(x)
y = x.unsqueeze(dim=0)
print(y.shape)
print(x.shape)


'输出'
torch.Size([1, 2, 3])
torch.Size([2, 3])

你可能感兴趣的:(pytorch,python,深度学习)