【无标题】

torch.cat()的用法:
一般为torch.cat((x1,x2),0)或torch.cat((x1,x2),两种形式),0表示按行连接,1表示按列连接,具体列子从pytorch官网上粘贴过来的,具体请参https://pytorch.org/docs/stable/generated/torch.cat.html#torch.cat

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])

此笔记为个人记录用,侵删。

你可能感兴趣的:(pytorch,深度学习,神经网络)