PyTorch中的unsqueeze函数(自用)

前言

最近在学习swin_transformer的模型搭建,其中用到了广播机制,在理解广播机制的过程中发现自己对torch.unsqueeze()函数比较困惑,所以做了个小实验帮助自己理解。

问题阐述

我们都知道,torch.unsqueeze()函数的作用是拓展张量维度,那么在不同位置拓展之后,原数据是怎样排列的呢?下面进入实验部分。

实验

>>> import torch
>>>
>>> a = torch.Tensor([1,2,3,4,5,6,7,8,9])
>>> print(a)
tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])
>>> b = a.view(3,3)
>>> print(b)
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
>>> c = b.unsqueeze(1)
>>> print(c)
tensor([[[1., 2., 3.]],

        [[4., 5., 6.]],

        [[7., 8., 9.]]])
>>> print(c.size())
torch.Size([3, 1, 3])
>>> d = b.unsqueeze(2)
>>> print(d)
tensor([[[1.],
         [2.],
         [3.]],

        [[4.],
         [5.],
         [6.]],

        [[7.],
         [8.],
         [9.]]])
>>> print(d.size())
torch.Size([3, 3, 1])
>>> e = b.unsqueeze(0)
>>> print(e)
tensor([[[1., 2., 3.],
         [4., 5., 6.],
         [7., 8., 9.]]])
  1. 创建Tensor变量a,此时a.shape = ([9]),这里解释为有9个元素。将a变成3行3列的矩阵b,此时b.shape = ([3,3]),这里解释为有3个通道,每个通道中有3个元素。
  2. c = b.unsqueeze(1),此时c.shape = ([3,1,3]),这里解释为有3个通道,每个通道有1行,每行有3列。
  3. d = b.unsqueeze(2),此时d.shape = ([3,3,1]),这里解释为有3个通道,每个通道有3行,每行有1列。

值得注意的是,上述步骤2,3中元素排列的方式不一样,2中元素水平排列,3中元素竖直排列,因而使用广播机制时需要进行复制的数值也不一样。

总结&拓展

至此,实验结束,对torch.unsqueeze()函数的理解加深不少。下面放上广播机制相关博文供参考。

PyTorch | 广播机制(broadcast)_pytorch broadcast-CSDN博客

 

你可能感兴趣的:(pytorch,pytorch,python,经验分享)