csdn上一些对torch.cat()中dim参数的描述存在误导性(注意:该函数是 torch.cat(),torch.nn 模块中并没有 cat 函数)。很多讲解只是以二维tensor为例,对于初学者很不友好,可能会误以为dim=0就是按行拼接,dim=1就是按列拼接。这种理解在做深度学习时可能带来致命性的错误:torch的输入往往是(batch, channel, H, W)四维张量,这时dim=0是按batch拼接,dim=1是按channel拼接。正确的理解是:dim=k 表示沿第 k 个维度拼接,除第 k 维外,其余各维的尺寸必须完全一致。
demo:
>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(2, 3)
>>> a
tensor([[-0.2588, -0.0599, -1.7341],
[-0.5032, -0.5763, -0.5034]])
>>> b
tensor([[-0.6648, 1.5979, 1.5730],
[ 0.4543, -0.4951, 1.3021]])
>>> torch.cat([a, b], 0)
tensor([[-0.2588, -0.0599, -1.7341],
[-0.5032, -0.5763, -0.5034],
[-0.6648, 1.5979, 1.5730],
[ 0.4543, -0.4951, 1.3021]])
>>> torch.cat([a, b], 1)
tensor([[-0.2588, -0.0599, -1.7341, -0.6648, 1.5979, 1.5730],
[-0.5032, -0.5763, -0.5034, 0.4543, -0.4951, 1.3021]])
>>> c = torch.randn(2, 3, 3)
>>> d = torch.randn(2, 3, 3)
>>> c
tensor([[[-0.5136, 0.9830, 1.0458],
[ 0.2047, -0.9562, -0.5320],
[ 0.5590, 1.8764, 0.0365]],
[[ 1.2450, -0.6891, -0.6812],
[ 0.1304, 0.8681, -0.4488],
[ 0.1647, 0.0277, 0.4221]]])
>>> d
tensor([[[ 0.5374, -1.3877, -1.2990],
[ 1.0750, 2.6673, -0.1455],
[ 0.2429, -0.2498, -0.4213]],
[[-0.0743, -0.6336, -1.1928],
[ 0.2520, -1.2656, 1.0350],
[-1.4310, 2.7352, 0.0715]]])
>>> torch.cat([c, d], 2)
tensor([[[-0.5136, 0.9830, 1.0458, 0.5374, -1.3877, -1.2990],
[ 0.2047, -0.9562, -0.5320, 1.0750, 2.6673, -0.1455],
[ 0.5590, 1.8764, 0.0365, 0.2429, -0.2498, -0.4213]],
[[ 1.2450, -0.6891, -0.6812, -0.0743, -0.6336, -1.1928],
[ 0.1304, 0.8681, -0.4488, 0.2520, -1.2656, 1.0350],
[ 0.1647, 0.0277, 0.4221, -1.4310, 2.7352, 0.0715]]])