Pytorch基础 - 1. torch.squeeze() 和 unsqueeze()

tensor升维和降维是神经网络的基本操作,比如不同维feature融合等都需要改操作。常用的函数有torch.unsqueeze() 和 torch.unsqueeze()操作。

目录

1. tensor降维操作: torch.squeeze() 和 指定index 

2. tensor升维操作: torch.unsqueeze() 和 使用None

 3. torch.squeeze和torch.unsqueeze的另一种写法


1. tensor降维操作: torch.squeeze() 和 指定index 

(1) 使用torch.squeeze(input,dim),默认删除tensor中所有维度为1的维度,也可指定dim。torch.squeeze — PyTorch 1.13 documentation

import torch

if __name__ == '__main__':
    a = torch.randn((2, 1, 3, 1, 4))
    a1 = torch.squeeze(a)
    print(a1.shape)  # torch.Size([2, 3, 4])
    a2 = torch.squeeze(a, dim=1)
    print(a2.shape)  # torch.Size([2, 3, 1, 4])
    a3 = torch.squeeze(a, dim=3)
    print(a3.shape)  # torch.Size([2, 1, 3, 4])

(2) 也可使用index=0直接指定,使用torch.equal比较两者相等。

if __name__ == '__main__':
    a = torch.randn((2, 1, 3, 1, 4))
    a1 = torch.squeeze(a)
    print(a1.shape)  # torch.Size([2, 3, 4])

    a2 = a[:, 0, :, 0]
    print(a2.shape)  # torch.Size([2, 3, 4])

    print(torch.equal(a1, a2))  # True

2. tensor升维操作: torch.unsqueeze() 和 使用None

(1) torch.unsqueeze(input, dim) ,对指定的dim,执行升维操作,具体可参考官方文档以及如下示例。torch.unsqueeze — PyTorch 1.13 documentation

import torch

if __name__ == '__main__':
    a = torch.randn((2, 3, 4))
    a1 = torch.unsqueeze(a, dim=1)
    print(a1.shape)  # torch.Size([2, 1, 3, 4])
    a2 = torch.unsqueeze(a, dim=2)
    print(a2.shape)  # torch.Size([2, 3, 1, 4])

(2) 简单用法:使用None,使用None来增加新维度

import torch

if __name__ == '__main__':
    a = torch.randn((2, 3, 4))
    a1 = a[:, None, ...]
    print(a1.shape)  # torch.Size([2, 1, 3, 4])
    a2 = a[..., None, :]
    print(a2.shape)  # torch.Size([2, 3, 1, 4])

注意:a1中None后面的三个点可以省略,如下

import torch

if __name__ == '__main__':
    a = torch.randn((2, 3, 4))

    a1_old = a[:, None, ...]
    print(a1_old .shape)  # torch.Size([2, 1, 3, 4])
    a1_new = a[:, None]
    print(a1_new .shape)  # torch.Size([2, 1, 3, 4])

    print(torch.equal(a1_old, a1_new))  # True

 3. torch.squeeze和torch.unsqueeze的另一种写法

一般情况下使用torch.squeeze(x, dim=?)来进行降维,当然还可以直接使用 x.squeeze(dim=?)。

import torch

if __name__ == '__main__':
    a = torch.randn((2, 3, 4))
    a1 = torch.unsqueeze(a, dim=0)
    print(a1.shape)  # torch.Size([1, 2, 3, 4])
    # 另一种写法
    a2 = a.unsqueeze(dim=0)
    print(a2.shape)  # torch.Size([1, 2, 3, 4])

你可能感兴趣的:(#,Pytorch操作,深度学习(CV),pytorch,深度学习,torch.squeeze)