Pytorch tensor的拼接与拆分

tensor的拼接与拆分

cat函数

torch.cat(tensorsdim=0*out=None) → Tensor

在指定的维度dim上,连接给定的一组张量。所有张量必须具有相同的形状(连接维度除外)。

例如在0维度上合并df1和df2向量:

Pytorch tensor的拼接与拆分_第1张图片

 在1维度上df1和df2向量:

Pytorch tensor的拼接与拆分_第2张图片

 例子:成绩单的合并

【班级1~4 学生 得分】

【班级5~9 学生 得分】

a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
c = torch.cat([a,b],dim=0)
c.shape()
#[9,32,8]

stack函数

torch.stack(tensorsdim=0*out=None) → Tensor

要合并的两个tensor必须有相同的shape,会新添加一个维度,两个tensor的会沿着新增加的维度合并

例子:

一班:【32个学生 每个学生8门课程】

二班:【32个学生 每个学生8门课程】

stack之后变为【两个班级 每个班级32个学生 每个学生有8门课程】

a = torch.rand(32,8)
b = torch.rand(32,8)
torch.stack([a,b],dim=0).shape
#[2 32 8]

split函数

torch.split(tensorsplit_size_or_sectionsdim=0)

将张量分成块。每个块都是原始张量的视图。

  • 如果 split_size_or_sections 是整数类型,那么张量将被拆分为大小相等的块(如果可能)。如果沿给定维度 dim 的张量大小不能被 split_size 整除,则最后一个块会最小。
  • 如果 split_size_or_sections 是一个列表,那么张量将根据 split_size_or_sections 分为大小为 len(split_size_or_sections) 的块。

参数

  • tensor (Tensor) – 要拆分的张量

  • split_size_or_sections (int) or (list(int)) – 单个块的大小或每个块的大小列表

  • dim (int) – 沿其拆分张量的维度

举例说明:

>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))
>>> torch.split(a, [1,4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))

chunk函数

torch.chunk(inputchunksdim=0) → List of Tensors

将张量拆分为特定数量的块。每个块都是输入张量的一个视图。 如果沿给定维度 dim 的张量大小不能被块整除,则最后一个块将更小。

Parameters

  • input (Tensor) – 要拆分的张量

  • chunks (int) – 要返回的块数

  • dim (int) – 沿着其拆分张量的维度

例子1:

a = torch.rand(6,32,8)
b,c,d= torch.chunk(a,3,dim=0)
print(b.shape)
print(c.shape)
print(d.shape)

#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])

例子2:

a = torch.rand(5,32,8)
b,c,d= torch.chunk(a,3,dim=0)
print(b.shape)
print(c.shape)
print(d.shape)

#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
#torch.Size([1, 32, 8])

例子3:

a = torch.rand(5,32,8)
b,c= torch.chunk(a,2,dim=0)
print(b.shape)
print(c.shape)

#torch.Size([3, 32, 8])
#torch.Size([2, 32, 8])

  

你可能感兴趣的:(深度学习,pytorch,pytorch,深度学习,python)