Torch.cat

Torch.cat

  • Torch.cat的用法
    • 实例

Torch.cat的用法

torch.cat是将两个张量(tensor)拼接在一起,cat是concatnate的意思,即拼接,联系在一起。

实例

// An highlighted block
import torch
a = torch.randn(1,3,5,5)
b = torch.randn(1,3,5,5)
print(a)
print(b)
>>> tensor([[[[ 0.4894,  0.9118, -0.7975,  0.7769, -1.0983],
          [-0.0617,  0.3230,  0.0853,  0.1426,  0.4373],
          [-0.7775, -0.4893,  0.3031, -0.5224,  0.7206],
          [ 0.0899, -1.2982,  0.3694, -0.6010,  1.0882],
          [ 0.7994, -0.0182, -0.2830, -0.1175,  2.3031]],

         [[-0.6270,  0.9806, -0.3543, -0.6706,  0.8451],
          [ 0.8559,  1.1715, -2.7926,  0.8195,  1.2003],
          [-0.9363,  0.6287,  0.8031, -1.1601,  0.5090],
          [-1.1433, -0.5224,  0.4913, -1.2035,  1.1474],
          [ 0.6201, -0.1981, -1.1308, -1.9613, -0.5917]],

         [[-0.0586, -1.1031, -0.0804, -0.2093, -2.4915],
          [ 1.1992, -1.9065, -0.9396,  0.3971, -0.1479],
          [-0.2771, -1.3371, -0.1468, -0.0249,  0.0760],
          [-0.9427, -1.0914, -0.0847, -1.0619,  0.8419],
          [ 0.8154, -0.1618, -0.0244,  0.3523, -1.3139]]]])
>>>tensor([[[[ 0.1765, -0.2371,  0.3850,  0.3014,  1.3498],
          [-0.5725,  1.1764,  0.7769,  0.7970, -0.5984],
          [-0.8498,  0.3575,  0.8842,  1.8408, -0.7673],
          [-2.0848, -2.4115, -0.1191, -1.3151, -0.2261],
          [ 0.8543,  0.0785, -0.4349, -1.3560,  0.0721]],

         [[-0.8831,  0.2914, -0.0772, -0.1918, -0.9889],
          [ 2.0799,  0.3074, -0.7013, -1.5068,  1.2838],
          [-1.1274,  0.2503,  0.9909, -1.0574,  0.1395],
          [-1.2156, -1.3117,  0.5919,  2.5695, -1.5748],
          [-0.4077,  0.8041, -1.5757, -0.0711, -0.6129]],

         [[-1.6921,  0.0097, -0.3866,  0.5965, -1.3929],
          [ 0.2597, -0.6740,  0.3119, -1.9251, -1.6731],
          [ 0.0244,  0.7889, -0.1629, -0.9620, -0.2372],
          [-1.5149,  0.4383, -1.5867, -1.0003,  0.0335],
          [ 0.1328, -1.6683, -1.3638,  0.0362, -0.4178]]]])
c = torch.cat([a,b], dim=1) #concatnate on dim=1
print(c)
>>>tensor([[[[ 0.4894,  0.9118, -0.7975,  0.7769, -1.0983],
          [-0.0617,  0.3230,  0.0853,  0.1426,  0.4373],
          [-0.7775, -0.4893,  0.3031, -0.5224,  0.7206],
          [ 0.0899, -1.2982,  0.3694, -0.6010,  1.0882],
          [ 0.7994, -0.0182, -0.2830, -0.1175,  2.3031]],
         [[-0.6270,  0.9806, -0.3543, -0.6706,  0.8451],
          [ 0.8559,  1.1715, -2.7926,  0.8195,  1.2003],
          [-0.9363,  0.6287,  0.8031, -1.1601,  0.5090],
          [-1.1433, -0.5224,  0.4913, -1.2035,  1.1474],
          [ 0.6201, -0.1981, -1.1308, -1.9613, -0.5917]],
         [[-0.0586, -1.1031, -0.0804, -0.2093, -2.4915],
          [ 1.1992, -1.9065, -0.9396,  0.3971, -0.1479],
          [-0.2771, -1.3371, -0.1468, -0.0249,  0.0760],
          [-0.9427, -1.0914, -0.0847, -1.0619,  0.8419],
          [ 0.8154, -0.1618, -0.0244,  0.3523, -1.3139]],
         [[ 0.1765, -0.2371,  0.3850,  0.3014,  1.3498],
          [-0.5725,  1.1764,  0.7769,  0.7970, -0.5984],
          [-0.8498,  0.3575,  0.8842,  1.8408, -0.7673],
          [-2.0848, -2.4115, -0.1191, -1.3151, -0.2261],
          [ 0.8543,  0.0785, -0.4349, -1.3560,  0.0721]],
         [[-0.8831,  0.2914, -0.0772, -0.1918, -0.9889],
          [ 2.0799,  0.3074, -0.7013, -1.5068,  1.2838],
          [-1.1274,  0.2503,  0.9909, -1.0574,  0.1395],
          [-1.2156, -1.3117,  0.5919,  2.5695, -1.5748],
          [-0.4077,  0.8041, -1.5757, -0.0711, -0.6129]],
         [[-1.6921,  0.0097, -0.3866,  0.5965, -1.3929],
          [ 0.2597, -0.6740,  0.3119, -1.9251, -1.6731],
          [ 0.0244,  0.7889, -0.1629, -0.9620, -0.2372],
          [-1.5149,  0.4383, -1.5867, -1.0003,  0.0335],
          [ 0.1328, -1.6683, -1.3638,  0.0362, -0.4178]]]])


c.shape
Out[8]: torch.Size([1, 6, 5, 5])


c = torch.cat([a,b], dim=0)   # concatnate on dim=0
c.shape
Out[10]: torch.Size([2, 3, 5, 5])

根据实验可以看出,torch.cat([a,b],dim=1) 是在channel维度上对ab进行了concatnate。同理,dim=0就是对batch维度上进行拼接,此时的shape就会变成torch.Size([2,3, 5, 5]).

你可能感兴趣的:(CNN,语义分割)