pytorch中的torch.unsqueeze和squeeze张量维度变化问题

顾名思义:unsqueeze,扩展维度,返回一个新的张量,对输入的既定位置插入维度 1

                  squeeze,压缩维度,将输入张量形状中的1 去除并返回。

torch.unsqueeze(input, dim)

torch.squeeze(input, dim)

  • tensor (Tensor) – 输入张量
  • dim (int) – 插入/消除 维度的索引

以下用一个二维张量进行举例:

压缩维度仅对(0,1)索引进行示例,(-1,-2)原理类似

import torch

x = torch.Tensor([[1, 2, 3, 4],
                 [5,6,7,8]])  
print('#' * 50)
print(x)  
print(x.size())  
print(x.dim())  

##########
print('#' * 50)
print(torch.unsqueeze(x, 0))  
print(torch.unsqueeze(x, 0).size())  
print(torch.unsqueeze(x, 0).dim())  
m=torch.unsqueeze(x, 0)
print(m.squeeze(0))
n=m.squeeze(0)
print(n.size())
print(n.dim())

##########
print('#' * 50)
print(torch.unsqueeze(x, 1))
print(torch.unsqueeze(x, 1).size())  
print(torch.unsqueeze(x, 1).dim())  
a=torch.unsqueeze(x, 1)
print(a.squeeze(1))
b=a.squeeze(1)
print(b.size())
print(b.dim())

##########
print('#' * 50)
print(torch.unsqueeze(x, -1))
print(torch.unsqueeze(x, -1).size())  
print(torch.unsqueeze(x, 1).dim())

##########
print('#' * 50)
print(torch.unsqueeze(x, -2))  
print(torch.unsqueeze(x, -2).size())  
print(torch.unsqueeze(x, -2).dim())  

相应结果:

##################################################
tensor([[1., 2., 3., 4.], 
        [5., 6., 7., 8.]])
torch.Size([2, 4])
2
##################################################
tensor([[[1., 2., 3., 4.],
         [5., 6., 7., 8.]]])
torch.Size([1, 2, 4])
3
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
torch.Size([2, 4])
2
##################################################
tensor([[[1., 2., 3., 4.]],

        [[5., 6., 7., 8.]]])
torch.Size([2, 1, 4])
3
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
torch.Size([2, 4])
2
##################################################
tensor([[[1.],
         [2.],
         [3.],
         [4.]],

        [[5.],
         [6.],
         [7.],
         [8.]]])
torch.Size([2, 4, 1])
3
##################################################
tensor([[[1., 2., 3., 4.]],

        [[5., 6., 7., 8.]]])
torch.Size([2, 1, 4])
3

你可能感兴趣的:(神经网络模型,windows,后端,mssql,python,开发语言)