pytorch:torch.cat

x = torch.rand(16, 85, 768) # 维度数(2,1,0)
y = torch.rand(16, 85, 768)

如果要左右拼接,dim要等于2(最左边的维度2),列数增加:768*2
z = torch.cat((x,y),2)
print(z.shape) # torch.Size([16, 85, 1536])

z = torch.cat((x,y),1)
print(z.shape) # torch.Size([16, 170, 768])

如果要左上下拼接,dim要等于0(最右边的维度0),行数增加:16*2
z = torch.cat((x,y),0)
print(z.shape) # torch.Size([32, 85, 768])

即:上下拼接要列数相同,左右拼接要行数相同。

你可能感兴趣的:(笔记,pytorch,深度学习,python)