Pytorch:Tensor的合并与分割

关键方法一览

方法 作用 区别
cat 合并 保持原有维度的数量
stack 合并 原有维度数量加1
split 分割 按照长度去分割
chunk 分割 等分

要点细述

cat

catconcatenate(连接)的缩写,而不是指(猫)。作用是把2个tensor按照特定的维度连接起来。
要求:除被拼接的维度外,其他维度必须相同

Code Demo

import torch
a=torch.randn(3,4) #随机生成一个shape(3,4)的tensort
b=torch.randn(2,4) #随机生成一个shape(2,4)的tensor

torch.cat([a,b],dim=0) 
#返回一个shape(5,4)的tensor
#把a和b拼接成一个shape(5,4)的tensor,
#可理解为沿着行增加的方向(即纵向)拼接

stack

stack会增加一个新的维度,来表示拼接后的2个tensor,直观些理解的话,咱们不妨把一个2维的tensor理解成一张长方形的纸张,cat相当于是把两张纸缝合在一起,形成一张更大的纸,而stack相当于是把两张纸上下堆叠在一起。
要求:两个tensor拼接前的形状完全一致

Code Demo

a=torch.randn(3,4)
b=torch.randn(3,4)

c=torch.stack([a,b],dim=0)
#返回一个shape(2,3,4)的tensor,新增的维度2分别指向a和b

d=torch.stack([a,b],dim=1)
#返回一个shape(3,2,4)的tensor,新增的维度2分别指向相应的a的第i行和b的第i行

助记:
这里的关键词参数dim的理解和cat方法中有些区别。

cat方法中可以理解为原tensor的维度,dim=0,就是沿着原来的0轴进行拼接,dim=1,就是沿着原来的1轴进行拼接。

stack方法中的dim则是指向新增维度的位置,dim=0,就是在新形成的tensor的维度的第0个位置新插入维度

split

split是根据长度去拆分tensor

Code Demo

a=torch.randn(3,4)

a.split([1,2],dim=0)
#把维度0按照长度[1,2]拆分,形成2个tensor,
#shape(1,4)和shape(2,4)

a.split([2,2],dim=1)
#把维度1按照长度[2,2]拆分,形成2个tensor,
#shape(3,2)和shape(3,2)

chunk

chunk可以理解为均等分的split,但是当维度长度不能被等分份数整除时,虽然不会报错,但可能结果与预期的不一样,建议只在可以被整除的情况下运用

Code Demo

a=torch.randn(4,6)

a.chunk(2,dim=0)
#返回一个shape(2,6)的tensor
a.chunk(2,dim=1)
#返回一个shape(4,3)的tensor

你可能感兴趣的:(Pytorch:Tensor的合并与分割)