torch.cat()和torch.stack()的区别,例子解释

torch.stack

torch.stack(inputs, dim=0)

inputs:同样也是张量序列,沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。

torch.cat

torch.cat(inputs, dim=0)

inputs:必须是张量序列, 在给定维度上对输入的张量序列进行连接操作,序列中所有的张量都应该为相同形状。

#注:由字面意思可以看出, cat可以理解为续接,不会增加维度,stack可以理解为叠加,会新加增加一个维度(增加的维度根据输入的dim而定)。

直接上例子, 沿第0维进行操作

x1 = torch.tensor([[1,2,3], [4,5,6]])# x1.shape = tensor.size([2,3])
x2 = torch.tensor([[7,8,9], [10,11,12]])
x = [x1, x2]
print(x1.shape)
# print(len(c), c[1].shape)
print('沿第0维进行操作:')
y1 = torch.stack(x, dim=0)
y2 = torch.cat(x, dim=0)
print('y1:', y1.shape,'\n',y1)
print('y2:', y2.shape,'\n',y2)

输出为

torch.Size([2, 3])
沿第0维进行操作:
y1: torch.Size([2, 2, 3]) 
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
y2: torch.Size([4, 3]) 
 tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

从y1的输出可以看到,stack直接将x1和x2的第0维进行叠加,即输出为[x1, x2],shape由[2,3]变为[2, 2, 3]。从y2的输出可以看到,cat在第0维将x1和x2元素进行续接,即输出为[x1[0], x1[1], x2[0], x2[1]], shape由[2, 3]变为[4,3]

沿第1维进行操作

print('沿第1维进行操作:')
y1 = torch.stack(x, dim=1)
y2 = torch.cat(x, dim=1)
print('y1:', y1.shape,'\n',y1)
print('y2:', y2.shape,'\n',y2)

输出为:

沿第1维进行操作:
y1: torch.Size([2, 2, 3]) 
 tensor([[[ 1,  2,  3],
         [ 7,  8,  9]],

        [[ 4,  5,  6],
         [10, 11, 12]]])
y2: torch.Size([2, 6]) 
 tensor([[ 1,  2,  3,  7,  8,  9],
        [ 4,  5,  6, 10, 11, 12]])

从y1的输出可以看到,stack直接将x1和x2相对应的第1维的元素进行叠加,即输出为[[x1[0], x2[0]], [x1[1], x2[1]],shape由[2,3]变为[2, 2, 3]。。
从y2的输出可以看到,cat将x1和x2相对应的第1维的元素进行续接, shape由[2,3]变为[2, 6]。。拿x1[0]和x2[0]举例来说,x1[0] = [1, 2, 3], x2[0] = [7, 8, 9],对它们进行续接即可得到[1, 2, 3, 7, 8, 9]

对于torch.stack需要特别注意的是,其插入的维度是介于 0 与 待连接的张量序列数之间,stack的输入维度范围要比cat大1。对于以上例子,因为x1.shape=[2,3],所以不能torch.cat(x, dim=2)。但可以torch.stack(x, dim=2),其输出的shape为[2, 3, 2],结果为

>>>torch.stack(x, dim=2)
tensor([[[ 1,  7],
         [ 2,  8],
         [ 3,  9]],

        [[ 4, 10],
         [ 5, 11],
         [ 6, 12]]])

从以上结果可以看出,torch.stack(x, dim=2)是将x1[i][j]x2[i][j]堆叠在一起的。如x1[0][0]=1x2[0][0]=7堆叠在一起,得到[1, 7]

你可能感兴趣的:(pytorch)