torch.cat 和 torch.stack的区别

torch.cat(seq,dim=0,out=None) # 沿着dim连接seq中的tensor, 所有的tensor必须有相同的size或为empty, 其相反的操作为 torch.split() 和torch.chunk()
torch.stack(seq, dim=0, out=None) #同上

#注: .cat 和 .stack的区别在于 cat会增加现有维度的值,可以理解为续接,stack会新加增加一个维度,可以
理解为叠加
>>> a=torch.Tensor([1,2,3])
>>> torch.stack((a,a)).size()
torch.size(2,3)
>>> torch.cat((a,a)).size()
torch.size(6)

 

你可能感兴趣的:(torch.cat 和 torch.stack的区别)