Pytorch学习-张量的合并和分割

可以用torch.cat方法和torch.stack方法将多个张量合并,可以用torch.split方法把一个张量分割成多个张量。

torch.cat和torch.stack有略微的区别,torch.cat是连接,不会增加维度,而torch.stack是堆叠,会增加维度。

a = torch.tensor([[1, 2],[3,4]])
b = torch.tensor([[5,6],[7,8]])
c = torch.tensor([[9,10],[11,12]])

abc_cat = torch.cat([a,b,c],axis = 0)
print(abc_cat.shape)
print(abc_cat)

abc_cat1 = torch.cat([a,b,c],axis = 1)
print(abc_cat1.shape)
print(abc_cat1)

Pytorch学习-张量的合并和分割_第1张图片

abc_stack = torch.stack([a,b,c],axis = 0)
print(abc_stack.shape)
print(abc_stack)

abc_stack1 = torch.stack([a,b,c],axis = 1)
print(abc_stack1.shape)
print(abc_stack1)

Pytorch学习-张量的合并和分割_第2张图片
torch.split是torch.cat的逆运算,可以指定分割份数平均分割,也可以通过指定每份的记录数量进行分割。

print(abc_cat)

a,b,c = torch.split(abc_cat,split_size_or_sections=2,dim = 0)
print(a)
print(b)
print(c)

Pytorch学习-张量的合并和分割_第3张图片

a,b,c = torch.split(abc_cat,split_size_or_sections=[4,1,1],dim = 0)
print(a)
print(b)
print(c)

Pytorch学习-张量的合并和分割_第4张图片

你可能感兴趣的:(Pytorch,python)