在讲两者区别的之前我们首先看一下官方的定义:
共同点:两个函数都是将tensor数据在指定的维度上进行一个拼接处理,并且要保证进行拼接前的数据形状是一样的
区别在于:torch.stack()处理之后会增加一个维度,选择的维度可以超过自身维度范围;torch.cat()处理之后维度不会变,选择的维度不能超过自身范围
看以下几个例子:
一、torch.stack()
>>> a = torch.tensor([[1,2],[3,4]])
>>> a.shape
torch.Size([2, 2])
>>> b = torch.tensor([5,6],[7,8])
>>> b.shape
torch.Size([2, 2])
** dim=0时 **
>>> c=torch.stack([a,b],dim=0)
>>> c
tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
>>> c.shape
torch.Size([2, 2, 2])
**dim=1时**
>>> d = torch.stack([a,b],dim=1)
>>> d
tensor([[[1, 2],
[5, 6]],
[[3, 4],
[7, 8]]])
>>> d.shape
torch.Size([2, 2, 2])
**dim =2时**
>>> e = torch.stack([a,b],dim=2)
>>> e
tensor([[[1, 5],
[2, 6]],
[[3, 7],
[4, 8]]])
>>> e.shape
torch.Size([2, 2, 2])
上面的例子可以看出在torch.stack()不同的维度进行拼接,结果是不一样的,但是最后的维度都是从2维变成了3维度
看了之后或许有小伙伴和当时的我一样还是懵懵懂懂的,下面我写一下另一种理解方式(个人理解,如果有不对的地方欢迎指正),看完一定会明白!!!
我们还是以上面的几个程序为例子:
a = [ a [ 0 ] [ 0 ] a [ 0 ] [ 1 ] a [ 1 ] [ 0 ] a [ 1 ] [ 1 ] ] a = \;\left[ \begin{array}{l} a[0][0]\;\;a[0][1]\\ a[1][0]\;\;a[1][1] \end{array} \right] a=[a[0][0]a[0][1]a[1][0]a[1][1]]
b = [ b [ 0 ] [ 0 ] b [ 0 ] [ 1 ] b [ 1 ] [ 0 ] b [ 1 ] [ 1 ] ] b = \;\left[ \begin{array}{l} b[0][0]\;\;b[0][1]\\ b[1][0]\;\;b[1][1] \end{array} \right] b=[b[0][0]b[0][1]b[1][0]b[1][1]]
1、在dim=0维度进行拼接时,首先会在dim=0处添加一个维度变成3个维度,上面两组tensor数据分别变成:
按照坐标进行排序则变成拼接之后的数据,形状为(2,2,2)
,拼接顺序为(a,b),所以a的第一维度是0,b的第一维度是1,
那么结果为
tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
2、当dim=1维度进行拼接时候首先会在dim=1处添加一个维度变成3个维度
按照下标进行排序之后如下:
画线的部分是对应的原数据,没有画线的是新增的一个维度
所以最终结果为
tensor([[[1, 2],
[5, 6]],
[[3, 4],
[7, 8]]])
3、当dim=2维度进行拼接时候,我们还是按照前面的两个例子来操作
然后再进行排序,最后的结果为:
tensor([[[1, 5],
[2, 6]],
[[3, 7],
[4, 8]]])
二、torch.cat()
该函数不用新增一个维度,所以理解起来就比较容易,以下是几个例子:
**dim=0**
>>> c = torch.cat([a,b],dim=0)
>>> c
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
**dim=1**
>>> d = torch.cat([a,b],dim=1)
>>> d
tensor([[1, 2, 5, 6],
[3, 4, 7, 8]])
>>> d.shape
torch.Size([2, 4])
**dim=2** **当dim=2时会报错,因为超出了自身维度范围**
>>> e = torch.cat([a,b],dim=2)
Traceback (most recent call last):
File "" , line 1, in <module>
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
对于torch.cat()的理解,可以理解为当dim=0时在0维度上拼接(或者理解为是竖着拼接),dim=1时在1维度上拼接(或者理解为横着拼接)
还是以上面a,b两个tensor为例子:
1、dim=0时
c = [ a [ 0 ] [ 0 ] a [ 0 ] [ 1 ] a [ 1 ] [ 0 ] a [ 1 ] [ 1 ] b [ 0 ] [ 0 ] b [ 0 ] [ 1 ] b [ 1 ] [ 0 ] b [ 1 ] [ 1 ] ] c = \;\left[ \begin{array}{l} a[0][0]\;\;a[0][1]\\ a[1][0]\;\;a[1][1]\\ b[0][0]\;\;b[0][1]\\ b[1][0]\;\;b[1][1] \end{array} \right] c=⎣⎢⎢⎡a[0][0]a[0][1]a[1][0]a[1][1]b[0][0]b[0][1]b[1][0]b[1][1]⎦⎥⎥⎤
tensor([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
2、dim=1时
d = [ a [ 0 ] [ 0 ] a [ 0 ] [ 1 ] b [ 0 ] [ 0 ] b [ 0 ] [ 1 ] a [ 1 ] [ 0 ] a [ 1 ] [ 1 ] b [ 1 ] [ 0 ] b [ 1 ] [ 1 ] ] d = \;\left[ \begin{array}{l} a[0][0]\;\;a[0][1]\;\;b[0][0]\;\;b[0][1]\\ a[1][0]\;\;a[1][1]\;\;b[1][0]\;\;b[1][1] \end{array} \right] d=[a[0][0]a[0][1]b[0][0]b[0][1]a[1][0]a[1][1]b[1][0]b[1][1]]
tensor([[1, 2, 5, 6],
[3, 4, 7, 8]])