cat即concatenate的意思,是指沿着已有的数据的 某一维度进行拼接,操作后数据的总维数不变,在进行拼接时,除了拼 接的维度之外,其他维度必须相同。
而torch.stack()函数指新增维度,并按照指定的维度进行叠加
import torch
a = torch.randperm(10)
a = a.reshape([2, 5])
b = torch.Tensor([[10, 11, 15, 12, 13], [5, 4, 3, 2, 1]])
a = a.type_as(b) # 将a,b数据类型转换一致
c = a + b # 不转换数据类型,print(a+b)会报错
torch.cat([a,b])
# torch.cat([a,b], 0)
Out[5]:
tensor([[ 6., 9., 8., 1., 4.],
[ 5., 2., 3., 0., 7.],
[10., 11., 15., 12., 13.],
[ 5., 4., 3., 2., 1.]])
torch.cat([a,b],1)
Out[6]:
tensor([[ 6., 9., 8., 1., 4., 10., 11., 15., 12., 13.],
[ 5., 2., 3., 0., 7., 5., 4., 3., 2., 1.]])
# 以第0维进行stack,叠加的基本单位为序列本身,即a与b,因此输出[a, b],
torch.stack([a,b], 0)
Out[7]:
tensor([[[ 6., 9., 8., 1., 4.],
[ 5., 2., 3., 0., 7.]],
[[10., 11., 15., 12., 13.],
[ 5., 4., 3., 2., 1.]]])
# 以第1维进行stack,叠加的基本单位为每一行
torch.stack([a,b], 1)
Out[64]:
tensor([[[ 6., 9., 8., 1., 4.],
[10., 11., 15., 12., 13.]],
[[ 5., 2., 3., 0., 7.],
[ 5., 4., 3., 2., 1.]]])
# 以第2维进行stack,叠加的基本单位为每一行的每一个元素
torch.stack([a,b], 2)
Out[67]:
tensor([[[ 6., 10.],
[ 9., 11.],
[ 8., 15.],
[ 1., 12.],
[ 4., 13.]],
[[ 5., 5.],
[ 2., 4.],
[ 3., 3.],
[ 0., 2.],
[ 7., 1.]]])