【pytorch】torch.cat和torch.stack的区别

【pytorch】torch.cat和torch.stack的区别

torch.cat

import torch
a=torch.randn((1,3,4,4)) #假设代表了[N,c,w,h]

b=torch.cat((a,a)) #维度默认是0
# (2, 3, 4, 4)

c=torch.cat((a,a),dim=1)
# (1, 6, 4, 4)

接下来看一些维度不同的

import torch
a=torch.randn((1,3,4,4)) #[N,c,w,h]
b=torch.randn((1,3,5,5))
c=torch.cat((a,b))
'''
invalid argument 0: Sizes of tensors must match except in dimension 0.
'''

也就是说,要拼接的tensor除了dim参数,其他的维度都必须要一致

torch.stack

import torch
a=torch.randn((1,3,4,4)) #[N,c,w,h]
b=torch.stack((a,a))
# (2, 1, 3, 4, 4)

c=torch.stack((a,a),1)
# (1, 2, 3, 4, 4)

d=torch.stack((a,a),2)
# (1, 3, 2, 4, 4)

可以看出来,首先torch.stack会使维度的个数增加,那具体是怎么沿维度进行堆叠的呢?

import torch
a=torch.arange(1,7).reshape((3,2))
#tensor([[1, 2],
#        [3, 4],
#        [5, 6]])
b=torch.arange(10,70,10).reshape((3,2))
#tensor([[10, 20],
#        [30, 40],
#        [50, 60]])
c=torch.arange(100,700,100).reshape((3,2))
#tensor([[100, 200],
#        [300, 400],
#        [500, 600]])
d=torch.stack((a,b,c)) #(3, 3, 2)
#tensor([[[  1,   2],
#         [  3,   4],
#         [  5,   6]],

#        [[ 10,  20],
#         [ 30,  40],
#         [ 50,  60]],

#        [[100, 200],
#         [300, 400],
#         [500, 600]]])
e=torch.stack((a,b,c),1) #(3, 3, 2)
#tensor([[[  1,   2],
#         [ 10,  20],
#         [100, 200]],

#        [[  3,   4],
#         [ 30,  40],
#         [300, 400]],

#        [[  5,   6],
#         [ 50,  60],
#         [500, 600]]])
f=torch.stack((a,b,c),2) #(3, 2, 3)
#tensor([[[  1,  10, 100],
#         [  2,  20, 200]],

#        [[  3,  30, 300],
#         [  4,  40, 400]],

#        [[  5,  50, 500],
#         [  6,  60, 600]]])

首先会将要拼接的tensor按dim增加,比如说对于d的操作来说,先将a,b,c的维度变为(1,3,3),然后再按dim进行torch.cat的操作。

你可能感兴趣的:(pytorch)