pytorch-------Tensor的合并与拆分(6)

合并

cat

import torch
import numpy as np

##############cat########################
#dim一样。对于cat的dim可以不一样,其他的必须一样
a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
print(torch.cat([a,b],dim=0).shape)
print("####################")
torch.Size([9, 32, 8])

stack

#############stack#############
#俩个tensor维度必须完全一样,扩展一个维度
#俩个班级,每个班级32个学生,8门课
a = torch.rand(32,8)
b = torch.rand(32,8)
print(torch.stack([a,b],dim=0).shape)
torch.Size([2, 32, 8])

拆分

split:by len

#1.根据长度拆分,2.根据个数拆分
##################split:by len######################
a = torch.rand(2,32,8)
b = torch.rand(3,32,8)
a1,a2 = a.split(1,dim=0)###1这里是长度1
print(a1.shape)
print(a2.shape)
b1,b2 = b.split([2,1],dim=0)
print(b1.shape)
print(b2.shape)
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])
torch.Size([2, 32, 8])
torch.Size([1, 32, 8])

chunk:by num

################chunk:by num#########
a = torch.rand(2,32,8)
a1,a2 = a.chunk(2,dim=0)##分为多少个
print(a1.shape)
print(a2.shape)
torch.Size([1, 32, 8])
torch.Size([1, 32, 8])

你可能感兴趣的:(pytorch)