最近在学习swin_transformer的模型搭建,其中用到了广播机制,在理解广播机制的过程中发现自己对torch.unsqueeze()函数比较困惑,所以做了个小实验帮助自己理解。
我们都知道,torch.unsqueeze()函数的作用是拓展张量维度,那么在不同位置拓展之后,原数据是怎样排列的呢?下面进入实验部分。
>>> import torch
>>>
>>> a = torch.Tensor([1,2,3,4,5,6,7,8,9])
>>> print(a)
tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])
>>> b = a.view(3,3)
>>> print(b)
tensor([[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]])
>>> c = b.unsqueeze(1)
>>> print(c)
tensor([[[1., 2., 3.]],
[[4., 5., 6.]],
[[7., 8., 9.]]])
>>> print(c.size())
torch.Size([3, 1, 3])
>>> d = b.unsqueeze(2)
>>> print(d)
tensor([[[1.],
[2.],
[3.]],
[[4.],
[5.],
[6.]],
[[7.],
[8.],
[9.]]])
>>> print(d.size())
torch.Size([3, 3, 1])
>>> e = b.unsqueeze(0)
>>> print(e)
tensor([[[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]]])
值得注意的是,上述步骤2,3中元素排列的方式不一样,2中元素水平排列,3中元素竖直排列,因而使用广播机制时需要进行复制的数值也不一样。
至此,实验结束,对torch.unsqueeze()函数的理解加深不少。下面放上广播机制相关博文供参考。
PyTorch | 广播机制(broadcast)_pytorch broadcast-CSDN博客