torch.cat((A,B),dim=-1)

>>> import torch
>>> x = torch.randn(2,3)
>>> x
tensor([[-0.8168, -1.9389,  0.0781],
        [ 0.3570, -1.3199, -0.1600]])
>>> torch.cat((x,x),0)
tensor([[-0.8168, -1.9389,  0.0781],
        [ 0.3570, -1.3199, -0.1600],
        [-0.8168, -1.9389,  0.0781],
        [ 0.3570, -1.3199, -0.1600]])
>>> torch.cat((x,x),-1)
tensor([[-0.8168, -1.9389,  0.0781, -0.8168, -1.9389,  0.0781],
        [ 0.3570, -1.3199, -0.1600,  0.3570, -1.3199, -0.1600]])

dim=-1是两块直接拼接

你可能感兴趣的:(深度学习,人工智能,python)