The torch.cat() function

Some explanations of the dim parameter of torch.cat() on CSDN are misleading (note: the function is torch.cat, not torch.nn.cat). Many tutorials only use 2-D tensors as examples, which is unfriendly to beginners: one may wrongly conclude that dim=0 always means "concatenate by rows" and dim=1 always means "concatenate by columns". In deep learning this misunderstanding can be costly, because inputs in torch are usually shaped (batch, channel, H, W); there, dim=0 concatenates along the batch axis, dim=1 along the channel axis, and so on.

Demo:

>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(2, 3)
>>> a
tensor([[-0.2588, -0.0599, -1.7341],
        [-0.5032, -0.5763, -0.5034]])
>>> b
tensor([[-0.6648,  1.5979,  1.5730],
        [ 0.4543, -0.4951,  1.3021]])
>>> torch.cat([a, b], 0)
tensor([[-0.2588, -0.0599, -1.7341],
        [-0.5032, -0.5763, -0.5034],
        [-0.6648,  1.5979,  1.5730],
        [ 0.4543, -0.4951,  1.3021]])
>>> torch.cat([a, b], 1)
tensor([[-0.2588, -0.0599, -1.7341, -0.6648,  1.5979,  1.5730],
        [-0.5032, -0.5763, -0.5034,  0.4543, -0.4951,  1.3021]])
>>> c = torch.randn(2, 3, 3)
>>> d = torch.randn(2, 3, 3)
>>> c
tensor([[[-0.5136,  0.9830,  1.0458],
         [ 0.2047, -0.9562, -0.5320],
         [ 0.5590,  1.8764,  0.0365]],

        [[ 1.2450, -0.6891, -0.6812],
         [ 0.1304,  0.8681, -0.4488],
         [ 0.1647,  0.0277,  0.4221]]])
>>> d
tensor([[[ 0.5374, -1.3877, -1.2990],
         [ 1.0750,  2.6673, -0.1455],
         [ 0.2429, -0.2498, -0.4213]],

        [[-0.0743, -0.6336, -1.1928],
         [ 0.2520, -1.2656,  1.0350],
         [-1.4310,  2.7352,  0.0715]]])
>>> torch.cat([c, d], 2)
tensor([[[-0.5136,  0.9830,  1.0458,  0.5374, -1.3877, -1.2990],
         [ 0.2047, -0.9562, -0.5320,  1.0750,  2.6673, -0.1455],
         [ 0.5590,  1.8764,  0.0365,  0.2429, -0.2498, -0.4213]],

        [[ 1.2450, -0.6891, -0.6812, -0.0743, -0.6336, -1.1928],
         [ 0.1304,  0.8681, -0.4488,  0.2520, -1.2656,  1.0350],
         [ 0.1647,  0.0277,  0.4221, -1.4310,  2.7352,  0.0715]]])
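To make the point about 4-D inputs concrete, here is a minimal sketch using hypothetical (batch, channel, H, W) tensors; the variable names are illustrative, not from the original post:

```python
import torch

# Hypothetical 4-D inputs shaped (batch, channel, H, W), as in a typical CNN.
x = torch.randn(2, 3, 4, 4)
y = torch.randn(2, 3, 4, 4)

# dim=0 concatenates along the batch axis: 2 + 2 = 4 samples.
batch_cat = torch.cat([x, y], dim=0)
print(batch_cat.shape)  # torch.Size([4, 3, 4, 4])

# dim=1 concatenates along the channel axis: 3 + 3 = 6 channels.
channel_cat = torch.cat([x, y], dim=1)
print(channel_cat.shape)  # torch.Size([2, 6, 4, 4])
```

The rule of thumb: the size of the chosen dim is the sum of the inputs' sizes along that dim, and every other dim must match exactly.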
