pytorch张量的拆分与合并

在 PyTorch 中,对张量 (Tensor) 进行拆分通常会用到两个函数:

 torch.split [按块大小拆分张量]
 torch.chunk [按块数拆分张量]
而对张量 (Tensor) 进行拼接通常会用到另外两个函数:

 torch.cat [按已有维度拼接张量]
 torch.stack [按新维度拼接张量]

张量的拆分

torch.split() 按照块的大小进行划分

import torch
#定义一个四维张量
x = torch.randn(1, 64,32,32)
#按照维度1 即64 按照块大小为4进行划分,一共划分了16个块,返回的是一个列表
o1=torch.split(x, 4, dim = 1)
print(x.shape)
print(len(o1))
print(o1[1].shape)

torch.chunk() 按照块数进行划分

import torch
#定义一个四维张量
x = torch.randn(1, 64,32,32)
#按照维度1 即64 按照块的数目为4进行划分,划分每个块的大小是16,返回的是一个列表
o1=torch.chunk(x, 4, dim = 1)
print(x.shape)
print(len(o1))
print(o1[1].shape)

张量的拼接

torch.cat(tensor,dim)  在已有的维度上进行拼接

import torch
#定义两个个四维张量
x = torch.randn(1, 64,32,32)
y = torch.randn(1,64, 32,32)

#在维度1上进行拼接
out=torch.cat([x,y],dim=1)
print(out.shape)

torch.stack()   按照新的维度进行拼接

import torch
#定义两个个四维张量
x = torch.randn(1, 64,32,32)
y = torch.randn(1,64, 32,32)

#在维度0进行拼接
out=torch.stack([x,y],dim=0)
print(out.shape)


 

你可能感兴趣的:(pytorch,pytorch,深度学习,人工智能)