pytorch 基本操作(二)——拆分、合并

pytorch 基本操作(二)——拆分、合并

  • cat
  • stack
  • split
  • chunk
  • 参考文献:

cat

cat就是简单的合并操作,假设我们现在有两个tensors,一个维度为4x32x8,另一个为5x32x8,我们将这两个tensor在第一个维度上合并,如果你想在别的维度上合并,只需要改变dim的值就可以了,需要注意的是dim的指向值其他的维度必须相同。

a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
torch.cat([a,b], dim = 0).shape
# out:
# torch.Size([9, 32, 8])

stack

stack是建立一个新的维度,要求tensor的维度相同。假设我们现在有两个tensors,维度为4x32x8。

a = torch.rand(4,32,8)
b = torch.rand(4,32,8)
torch.stack([a,b],dim = 0).shape
# out:
# torch.Size([2, 4, 32, 8])

这样就在第一个位置将两个tensors堆叠在了一起,同理dim是调整堆叠的位置的。

split

相对而言split就是来拆分tensor的,假设一个4x32x8的tensor,我们把第一个维度拆分,拆成1和3。

a = torch.rand(4,32,8)
a1, a2 = a.split([1,3],dim = 0)
a1.shape
# out:
# torch.Size([1, 32, 8])

a2.shape
# out:
# torch.Size([3, 32, 8])

如果说把第一个维度按间隔为3进行分割。

a1, a2 = a.split(3,dim = 0)
a1.shape
# out:
# torch.Size([3, 32, 8])
a2.shape
# out:
# torch.Size([1, 32, 8])

chunk

chunk函数会将维度平均差分成几个部分,如果我们把4x32x8的tensor的第一个维度平均拆分为两个。

a = torch.rand(4,32,8)
a1, a2 = a.chunk(2,dim = 0)
a1.shape
# out:
# torch.Size([2, 32, 8])

基本的拆分函数就是这样,其余还有更多的参考帮助文档。

参考文献:

[1] https://github.com/irobbwu/pytorch-intro/blob/main/02.basic.ipynb

你可能感兴趣的:(深度学习,pytorch,stack)