torch.cat是将两个张量(tensor)拼接在一起,cat是concatenate的意思,即拼接,联系在一起。
# A highlighted block
import torch

# Two random 4-D tensors of identical shape (batch=1, channel=3, height=5, width=5),
# so they can be concatenated along any dimension.
a = torch.randn(1, 3, 5, 5)
b = torch.randn(1, 3, 5, 5)
print(a)
print(b)
>>> tensor([[[[ 0.4894, 0.9118, -0.7975, 0.7769, -1.0983],
[-0.0617, 0.3230, 0.0853, 0.1426, 0.4373],
[-0.7775, -0.4893, 0.3031, -0.5224, 0.7206],
[ 0.0899, -1.2982, 0.3694, -0.6010, 1.0882],
[ 0.7994, -0.0182, -0.2830, -0.1175, 2.3031]],
[[-0.6270, 0.9806, -0.3543, -0.6706, 0.8451],
[ 0.8559, 1.1715, -2.7926, 0.8195, 1.2003],
[-0.9363, 0.6287, 0.8031, -1.1601, 0.5090],
[-1.1433, -0.5224, 0.4913, -1.2035, 1.1474],
[ 0.6201, -0.1981, -1.1308, -1.9613, -0.5917]],
[[-0.0586, -1.1031, -0.0804, -0.2093, -2.4915],
[ 1.1992, -1.9065, -0.9396, 0.3971, -0.1479],
[-0.2771, -1.3371, -0.1468, -0.0249, 0.0760],
[-0.9427, -1.0914, -0.0847, -1.0619, 0.8419],
[ 0.8154, -0.1618, -0.0244, 0.3523, -1.3139]]]])
>>> tensor([[[[ 0.1765, -0.2371, 0.3850, 0.3014, 1.3498],
[-0.5725, 1.1764, 0.7769, 0.7970, -0.5984],
[-0.8498, 0.3575, 0.8842, 1.8408, -0.7673],
[-2.0848, -2.4115, -0.1191, -1.3151, -0.2261],
[ 0.8543, 0.0785, -0.4349, -1.3560, 0.0721]],
[[-0.8831, 0.2914, -0.0772, -0.1918, -0.9889],
[ 2.0799, 0.3074, -0.7013, -1.5068, 1.2838],
[-1.1274, 0.2503, 0.9909, -1.0574, 0.1395],
[-1.2156, -1.3117, 0.5919, 2.5695, -1.5748],
[-0.4077, 0.8041, -1.5757, -0.0711, -0.6129]],
[[-1.6921, 0.0097, -0.3866, 0.5965, -1.3929],
[ 0.2597, -0.6740, 0.3119, -1.9251, -1.6731],
[ 0.0244, 0.7889, -0.1629, -0.9620, -0.2372],
[-1.5149, 0.4383, -1.5867, -1.0003, 0.0335],
[ 0.1328, -1.6683, -1.3638, 0.0362, -0.4178]]]])
# Concatenate along dim=1 (the channel axis): 3 + 3 -> 6 channels,
# giving a result of shape (1, 6, 5, 5).
c = torch.cat((a, b), dim=1)
print(c)
>>> tensor([[[[ 0.4894, 0.9118, -0.7975, 0.7769, -1.0983],
[-0.0617, 0.3230, 0.0853, 0.1426, 0.4373],
[-0.7775, -0.4893, 0.3031, -0.5224, 0.7206],
[ 0.0899, -1.2982, 0.3694, -0.6010, 1.0882],
[ 0.7994, -0.0182, -0.2830, -0.1175, 2.3031]],
[[-0.6270, 0.9806, -0.3543, -0.6706, 0.8451],
[ 0.8559, 1.1715, -2.7926, 0.8195, 1.2003],
[-0.9363, 0.6287, 0.8031, -1.1601, 0.5090],
[-1.1433, -0.5224, 0.4913, -1.2035, 1.1474],
[ 0.6201, -0.1981, -1.1308, -1.9613, -0.5917]],
[[-0.0586, -1.1031, -0.0804, -0.2093, -2.4915],
[ 1.1992, -1.9065, -0.9396, 0.3971, -0.1479],
[-0.2771, -1.3371, -0.1468, -0.0249, 0.0760],
[-0.9427, -1.0914, -0.0847, -1.0619, 0.8419],
[ 0.8154, -0.1618, -0.0244, 0.3523, -1.3139]],
[[ 0.1765, -0.2371, 0.3850, 0.3014, 1.3498],
[-0.5725, 1.1764, 0.7769, 0.7970, -0.5984],
[-0.8498, 0.3575, 0.8842, 1.8408, -0.7673],
[-2.0848, -2.4115, -0.1191, -1.3151, -0.2261],
[ 0.8543, 0.0785, -0.4349, -1.3560, 0.0721]],
[[-0.8831, 0.2914, -0.0772, -0.1918, -0.9889],
[ 2.0799, 0.3074, -0.7013, -1.5068, 1.2838],
[-1.1274, 0.2503, 0.9909, -1.0574, 0.1395],
[-1.2156, -1.3117, 0.5919, 2.5695, -1.5748],
[-0.4077, 0.8041, -1.5757, -0.0711, -0.6129]],
[[-1.6921, 0.0097, -0.3866, 0.5965, -1.3929],
[ 0.2597, -0.6740, 0.3119, -1.9251, -1.6731],
[ 0.0244, 0.7889, -0.1629, -0.9620, -0.2372],
[-1.5149, 0.4383, -1.5867, -1.0003, 0.0335],
[ 0.1328, -1.6683, -1.3638, 0.0362, -0.4178]]]])
c.shape  # torch.Size([1, 6, 5, 5]) — channel count doubled by the dim=1 concat
Out[8]: torch.Size([1, 6, 5, 5])
# Concatenate along dim=0 (the batch axis) instead: 1 + 1 -> 2 samples,
# giving a result of shape (2, 3, 5, 5).
c = torch.cat((a, b), dim=0)
c.shape
Out[10]: torch.Size([2, 3, 5, 5])
根据实验可以看出,torch.cat([a,b], dim=1)
是在 channel 维度上对 a 和 b 进行了 concatenate。同理,dim=0
就是在 batch 维度上进行拼接,此时的 shape 就会变成 torch.Size([2, 3, 5, 5])。