Pytorch中 torch.squeeze() 和torch.unsqueeze()的用法

  • torch.squeeze()
    作用:维度压缩或者解压。
    torch.squeeze(input, dim=None, out=None)
    将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)
    注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
    参数:
    input (Tensor) – 输入张量
    dim (int, optional) – 如果给定,则input只会在给定维度挤压
    out (Tensor, optional) – 输出张量
  • 示例
>>> x = torch.zeros(1,1,2,1,3)
>>> x.dim()
5
>>> torch.squeeze(x).size()#去掉dim=1的维度
torch.Size([2, 3])
>>> torch.squeeze(x,0).size()  # dim=0表示第一维,且第一维的维度为1,所以去掉
torch.Size([1, 2, 1, 3])
>>> torch.squeeze(x,3).size()
torch.Size([1, 1, 2, 3])
>>> torch.squeeze(x,2).size()  # dim=2,第三维的维度为2!=1,所以不变
torch.Size([1, 1, 2, 1, 3])
  • torch.unsqueeze()
    作用:扩展维度
    torch.unsqueeze(input, dim, out=None)
    返回一个新的张量,对输入的既定位置插入维度 1

参考

https://github.com/shanglianlm0525/PyTorch-Networks

你可能感兴趣的:(Pytorch中 torch.squeeze() 和torch.unsqueeze()的用法)