torch.cat()与torch.stack()的区别

torch.cat与torch.stack的区别

在讲两者区别的之前我们首先看一下官方的定义:
torch.cat()与torch.stack()的区别_第1张图片
torch.cat()与torch.stack()的区别_第2张图片
共同点:两个函数都是将tensor数据在指定的维度上进行一个拼接处理,并且要保证进行拼接前的数据形状是一样的
区别在于:torch.stack()处理之后会增加一个维度,选择的维度可以超过自身维度范围;torch.cat()处理之后维度不会变,选择的维度不能超过自身范围

看以下几个例子:
一、torch.stack()

>>> a = torch.tensor([[1,2],[3,4]])
>>> a.shape
torch.Size([2, 2])
>>> b = torch.tensor([5,6],[7,8])
>>> b.shape
torch.Size([2, 2])

            ** dim=0**
>>> c=torch.stack([a,b],dim=0)
>>> c
tensor([[[1, 2],
         [3, 4]],
        [[5, 6],
         [7, 8]]])
>>> c.shape
torch.Size([2, 2, 2])

             **dim=1**
>>> d = torch.stack([a,b],dim=1)
>>> d
tensor([[[1, 2],
         [5, 6]],
        [[3, 4],
         [7, 8]]])
>>> d.shape
torch.Size([2, 2, 2])
            **dim =2**
>>> e = torch.stack([a,b],dim=2)
>>> e
tensor([[[1, 5],
         [2, 6]],
        [[3, 7],
         [4, 8]]])
>>> e.shape
torch.Size([2, 2, 2])

上面的例子可以看出在torch.stack()不同的维度进行拼接,结果是不一样的,但是最后的维度都是从2维变成了3维度

看了之后或许有小伙伴和当时的我一样还是懵懵懂懂的,下面我写一下另一种理解方式(个人理解,如果有不对的地方欢迎指正),看完一定会明白!!!

我们还是以上面的几个程序为例子:

a =    [ a [ 0 ] [ 0 ]      a [ 0 ] [ 1 ] a [ 1 ] [ 0 ]      a [ 1 ] [ 1 ] ] a = \;\left[ \begin{array}{l} a[0][0]\;\;a[0][1]\\ a[1][0]\;\;a[1][1] \end{array} \right] a=[a[0][0]a[0][1]a[1][0]a[1][1]]

b =    [ b [ 0 ] [ 0 ]      b [ 0 ] [ 1 ] b [ 1 ] [ 0 ]      b [ 1 ] [ 1 ] ] b = \;\left[ \begin{array}{l} b[0][0]\;\;b[0][1]\\ b[1][0]\;\;b[1][1] \end{array} \right] b=[b[0][0]b[0][1]b[1][0]b[1][1]]

1、在dim=0维度进行拼接时,首先会在dim=0处添加一个维度变成3个维度,上面两组tensor数据分别变成:
在这里插入图片描述
在这里插入图片描述
按照坐标进行排序则变成拼接之后的数据,形状为(2,2,2)
,拼接顺序为(a,b),所以a的第一维度是0,b的第一维度是1,
那么结果为

tensor([[[1, 2],
         [3, 4]],
        [[5, 6],
         [7, 8]]])

2、当dim=1维度进行拼接时候首先会在dim=1处添加一个维度变成3个维度
在这里插入图片描述
torch.cat()与torch.stack()的区别_第3张图片
按照下标进行排序之后如下:
torch.cat()与torch.stack()的区别_第4张图片
画线的部分是对应的原数据,没有画线的是新增的一个维度
所以最终结果为

tensor([[[1, 2],
         [5, 6]],
        [[3, 4],
         [7, 8]]])

3、当dim=2维度进行拼接时候,我们还是按照前面的两个例子来操作
在这里插入图片描述torch.cat()与torch.stack()的区别_第5张图片
然后再进行排序,最后的结果为:

tensor([[[1, 5],
         [2, 6]],
        [[3, 7],
         [4, 8]]])

二、torch.cat()
该函数不用新增一个维度,所以理解起来就比较容易,以下是几个例子:

       **dim=0**
>>> c = torch.cat([a,b],dim=0)
>>> c
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
        
       **dim=1**
>>> d = torch.cat([a,b],dim=1)
>>> d
tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])
>>> d.shape
torch.Size([2, 4])

       **dim=2**  **当dim=2时会报错,因为超出了自身维度范围**
>>> e = torch.cat([a,b],dim=2)
Traceback (most recent call last):
  File "", line 1, in <module>
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

对于torch.cat()的理解,可以理解为当dim=0时在0维度上拼接(或者理解为是竖着拼接),dim=1时在1维度上拼接(或者理解为横着拼接)

还是以上面a,b两个tensor为例子:
1、dim=0时
c =    [ a [ 0 ] [ 0 ]      a [ 0 ] [ 1 ] a [ 1 ] [ 0 ]      a [ 1 ] [ 1 ] b [ 0 ] [ 0 ]      b [ 0 ] [ 1 ] b [ 1 ] [ 0 ]      b [ 1 ] [ 1 ] ] c = \;\left[ \begin{array}{l} a[0][0]\;\;a[0][1]\\ a[1][0]\;\;a[1][1]\\ b[0][0]\;\;b[0][1]\\ b[1][0]\;\;b[1][1] \end{array} \right] c=a[0][0]a[0][1]a[1][0]a[1][1]b[0][0]b[0][1]b[1][0]b[1][1]

tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])

2、dim=1
d =    [ a [ 0 ] [ 0 ]      a [ 0 ] [ 1 ]      b [ 0 ] [ 0 ]      b [ 0 ] [ 1 ] a [ 1 ] [ 0 ]      a [ 1 ] [ 1 ]      b [ 1 ] [ 0 ]      b [ 1 ] [ 1 ] ] d = \;\left[ \begin{array}{l} a[0][0]\;\;a[0][1]\;\;b[0][0]\;\;b[0][1]\\ a[1][0]\;\;a[1][1]\;\;b[1][0]\;\;b[1][1] \end{array} \right] d=[a[0][0]a[0][1]b[0][0]b[0][1]a[1][0]a[1][1]b[1][0]b[1][1]]


tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])

你可能感兴趣的:(python,pytorch,深度学习)