组合之torch.cat()和 torch.stack()

cat即concatenate的意思,是指沿着已有的数据的 某一维度进行拼接,操作后数据的总维数不变,在进行拼接时,除了拼 接的维度之外,其他维度必须相同。
而torch.stack()函数指新增维度,并按照指定的维度进行叠加

import torch
a = torch.randperm(10)
a = a.reshape([2, 5])
b = torch.Tensor([[10, 11, 15, 12, 13], [5, 4, 3, 2, 1]])
a = a.type_as(b)  # 将a,b数据类型转换一致
c = a + b # 不转换数据类型,print(a+b)会报错
torch.cat([a,b])
# torch.cat([a,b], 0)
Out[5]: 
tensor([[ 6.,  9.,  8.,  1.,  4.],
        [ 5.,  2.,  3.,  0.,  7.],
        [10., 11., 15., 12., 13.],
        [ 5.,  4.,  3.,  2.,  1.]])


torch.cat([a,b],1)
Out[6]: 
tensor([[ 6.,  9.,  8.,  1.,  4., 10., 11., 15., 12., 13.],
        [ 5.,  2.,  3.,  0.,  7.,  5.,  4.,  3.,  2.,  1.]])
# 以第0维进行stack,叠加的基本单位为序列本身,即a与b,因此输出[a, b],
torch.stack([a,b], 0)
Out[7]: 
tensor([[[ 6.,  9.,  8.,  1.,  4.],
         [ 5.,  2.,  3.,  0.,  7.]],

        [[10., 11., 15., 12., 13.],
         [ 5.,  4.,  3.,  2.,  1.]]])

# 以第1维进行stack,叠加的基本单位为每一行
torch.stack([a,b], 1)
Out[64]: 
tensor([[[ 6.,  9.,  8.,  1.,  4.],
         [10., 11., 15., 12., 13.]],

        [[ 5.,  2.,  3.,  0.,  7.],
         [ 5.,  4.,  3.,  2.,  1.]]])

# 以第2维进行stack,叠加的基本单位为每一行的每一个元素
torch.stack([a,b], 2)
Out[67]: 
tensor([[[ 6., 10.],
         [ 9., 11.],
         [ 8., 15.],
         [ 1., 12.],
         [ 4., 13.]],

        [[ 5.,  5.],
         [ 2.,  4.],
         [ 3.,  3.],
         [ 0.,  2.],
         [ 7.,  1.]]])

你可能感兴趣的:(组合之torch.cat()和 torch.stack())