Pytorch基础 - 5. torch.cat() 和 torch.stack()

目录

1. torch.cat(tensors, dim)

2. torch.stack(tensors, dim)

3. 两者不同


torch.cat() 和 torch.stack()常用来进行张量的拼接,在神经网络里经常用到。且前段时间有一个面试官也问到了这个知识点,虽然内容很小很细,但需要了解。

1. torch.cat(tensors, dim)

  • tensors:待拼接的多个张量,可用list, tuple表示
  • dim:待拼接的维度,默认是0
  • 注意:tensors里不同张量对应的待拼接维度的size可以不一致,但是其他维度的size要保持一致。如代码中待拼接维度是0,x和y对应的维度0上的值不一样,但是其他维度上的值(维度1上的值)要保持一致,即都为4,否则会报错。

Pytorch基础 - 5. torch.cat() 和 torch.stack()_第1张图片

示例:新生成的tensor在dim=0这个维度进行了拼接,即 3 + 2 = 5,剩余维度保持不变

x = torch.rand(3, 4)
y = torch.rand(2, 4)
xy = torch.cat([x, y], dim=0)   
print(xy.shape)   # torch.Size([5, 4])

2. torch.stack(tensors, dim)

  • tensors:待拼接的多个张量,可用list, tuple表示
  • dim:待拼接的维度,默认是0
  • 注意:tensors里所有张量的维度要保持一致,否则会报错

Pytorch基础 - 5. torch.cat() 和 torch.stack()_第2张图片

x = torch.rand(7, 4)
y = torch.rand(7, 4)
z = torch.rand(7, 4)
xy = torch.stack([x, y, z])
print(xy.shape)   # torch.Size([3, 7, 4])

3. 两者不同

从上面的代码结果可看出两者区别:

  • torch.cat会在dim的维度上进行合并,不会扩展出新的维度
  • torch.stack则会在dim的维度上拓展出一个新的维度,然后进行拼接,该维度的大小为tensors的个数

你可能感兴趣的:(#,Pytorch操作,pytorch,深度学习,python)