torch.stack(inputs, dim=0)
inputs:同样也是张量序列,沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
torch.cat(inputs, dim=0)
inputs:必须是张量序列, 在给定维度上对输入的张量序列进行连接操作,序列中所有的张量都应该为相同形状。
#注:由字面意思可以看出, cat可以理解为续接,不会增加维度,stack可以理解为叠加,会新加增加一个维度(增加的维度根据输入的dim而定)。
x1 = torch.tensor([[1,2,3], [4,5,6]])# x1.shape = tensor.size([2,3])
x2 = torch.tensor([[7,8,9], [10,11,12]])
x = [x1, x2]
print(x1.shape)
# print(len(c), c[1].shape)
print('沿第0维进行操作:')
y1 = torch.stack(x, dim=0)
y2 = torch.cat(x, dim=0)
print('y1:', y1.shape,'\n',y1)
print('y2:', y2.shape,'\n',y2)
输出为
torch.Size([2, 3])
沿第0维进行操作:
y1: torch.Size([2, 2, 3])
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
y2: torch.Size([4, 3])
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
从y1的输出可以看到,stack直接将x1和x2的第0维进行叠加,即输出为[x1, x2]
,shape由[2,3]
变为[2, 2, 3]
。从y2的输出可以看到,cat在第0维将x1和x2元素进行续接,即输出为[x1[0], x1[1], x2[0], x2[1]]
, shape由[2, 3]
变为[4,3]
。
print('沿第1维进行操作:')
y1 = torch.stack(x, dim=1)
y2 = torch.cat(x, dim=1)
print('y1:', y1.shape,'\n',y1)
print('y2:', y2.shape,'\n',y2)
输出为:
沿第1维进行操作:
y1: torch.Size([2, 2, 3])
tensor([[[ 1, 2, 3],
[ 7, 8, 9]],
[[ 4, 5, 6],
[10, 11, 12]]])
y2: torch.Size([2, 6])
tensor([[ 1, 2, 3, 7, 8, 9],
[ 4, 5, 6, 10, 11, 12]])
从y1的输出可以看到,stack直接将x1和x2相对应的第1维的元素进行叠加,即输出为[[x1[0], x2[0]], [x1[1], x2[1]]
,shape由[2,3]
变为[2, 2, 3]
。。
从y2的输出可以看到,cat将x1和x2相对应的第1维的元素进行续接, shape由[2,3]
变为[2, 6]
。。拿x1[0]和x2[0]举例来说,x1[0] = [1, 2, 3], x2[0] = [7, 8, 9]
,对它们进行续接即可得到[1, 2, 3, 7, 8, 9]
。
对于torch.stack需要特别注意的是,其插入的维度是介于 0 与 待连接的张量序列数之间,stack的输入维度范围要比cat大1。对于以上例子,因为x1.shape=[2,3]
,所以不能torch.cat(x, dim=2)
。但可以torch.stack(x, dim=2)
,其输出的shape为[2, 3, 2]
,结果为
>>>torch.stack(x, dim=2)
tensor([[[ 1, 7],
[ 2, 8],
[ 3, 9]],
[[ 4, 10],
[ 5, 11],
[ 6, 12]]])
从以上结果可以看出,torch.stack(x, dim=2)
是将x1[i][j]
和x2[i][j]
堆叠在一起的。如x1[0][0]=1
和x2[0][0]=7
堆叠在一起,得到[1, 7]