Pytorch 数组的维度拼接 --- torch.cat() 与 torch.stack() 方法

1、torch.cat() 方法

import torch

a = torch.randn((2, 3, 3))
b = torch.randn((1, 3, 3))
# 指定 a, b 数组,在 0 维度拼接。
# 需要注意的是,除了指定的这一维度可以不同外,其他的维度的大小必须相同
c = torch.cat((a,b), 0)
print(c.shape)

2、torch.stack() 方法

import torch

a=torch.arange(12).reshape(3,4)
b=torch.ones(12).reshape(3,4)
c=torch.stack((a,b),dim=0)
d=torch.stack((a,b),dim=1)
e=torch.stack((a,b),dim=2)
# dim最大可到输入数组的维数,即a、b的维数
# 相当于先将多个n维数组进行扩维操作,然后再拼接为一个n+1维的数组
# a, b 数组的维度必须相同,在哪个维度拼接,即在扩增哪个维度
print(c.shape)
print(d.shape)
print(e.shape)

你可能感兴趣的:(一些小代码,参数解释,visual,studio,ide,visualstudio)