Pytorch疑难小实验:理解torch.cat()在不同维度下的连接方式

import torch 

def printt(x,a=""):
    print(x)
    print("{}.dim{}".format(a,x.dim()))
    print("{}.shape{}".format(a,x.shape))

x = torch.arange(48).reshape(2,2,3,4)
printt(x,"x")

y = torch.ones(48).reshape(2,2,3,4)
printt(y,"y")

a = torch.cat((x,y),dim = 0)
printt(a,"a")

b = torch.cat((x,y),dim = 1)
printt(b,"b")

c = torch.cat((x,y),dim = 2)
printt(c,"c")

d = torch.cat((x,y),dim = 3)
printt(d,"d")

文章目录

  • x:
  • y:
  • a = torch.cat((x,y),dim = 0)
  • b = torch.cat((x,y),dim = 1)
  • c = torch.cat((x,y),dim = 2)
  • d = torch.cat((x,y),dim = 3)

x:

tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11]],

         [[12, 13, 14, 15],
          [16, 17, 18, 19],
          [20, 21, 22, 23]]],


        [[[24, 25, 26, 27],
          [28, 29, 30, 31],
          [32, 33, 34, 35]],

         [[36, 37, 38, 39],
          [40, 41, 42, 43],
          [44, 45, 46, 47]]]])
x.dim4
x.shapetorch.Size([2, 2, 3, 4])

y:

tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]],


        [[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])
y.dim4
y.shapetorch.Size([2, 2, 3, 4])

a = torch.cat((x,y),dim = 0)

tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.]],

         [[12., 13., 14., 15.],
          [16., 17., 18., 19.],
          [20., 21., 22., 23.]]],


        [[[24., 25., 26., 27.],
          [28., 29., 30., 31.],
          [32., 33., 34., 35.]],

         [[36., 37., 38., 39.],
          [40., 41., 42., 43.],
          [44., 45., 46., 47.]]],


        [[[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],

         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]],


        [[[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],

         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]]])
a.dim4
a.shapetorch.Size([4, 2, 3, 4])

b = torch.cat((x,y),dim = 1)

tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.]],

         [[12., 13., 14., 15.],
          [16., 17., 18., 19.],
          [20., 21., 22., 23.]],

         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],

         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]],


        [[[24., 25., 26., 27.],
          [28., 29., 30., 31.],
          [32., 33., 34., 35.]],

         [[36., 37., 38., 39.],
          [40., 41., 42., 43.],
          [44., 45., 46., 47.]],

         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],

         [[ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]]])
b.dim4
b.shapetorch.Size([2, 4, 3, 4])

c = torch.cat((x,y),dim = 2)

tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],

         [[12., 13., 14., 15.],
          [16., 17., 18., 19.],
          [20., 21., 22., 23.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]],


        [[[24., 25., 26., 27.],
          [28., 29., 30., 31.],
          [32., 33., 34., 35.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]],

         [[36., 37., 38., 39.],
          [40., 41., 42., 43.],
          [44., 45., 46., 47.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.],
          [ 1.,  1.,  1.,  1.]]]])
c.dim4
c.shapetorch.Size([2, 2, 6, 4])

d = torch.cat((x,y),dim = 3)

tensor([[[[ 0.,  1.,  2.,  3.,  1.,  1.,  1.,  1.],
          [ 4.,  5.,  6.,  7.,  1.,  1.,  1.,  1.],
          [ 8.,  9., 10., 11.,  1.,  1.,  1.,  1.]],

         [[12., 13., 14., 15.,  1.,  1.,  1.,  1.],
          [16., 17., 18., 19.,  1.,  1.,  1.,  1.],
          [20., 21., 22., 23.,  1.,  1.,  1.,  1.]]],


        [[[24., 25., 26., 27.,  1.,  1.,  1.,  1.],
          [28., 29., 30., 31.,  1.,  1.,  1.,  1.],
          [32., 33., 34., 35.,  1.,  1.,  1.,  1.]],

         [[36., 37., 38., 39.,  1.,  1.,  1.,  1.],
          [40., 41., 42., 43.,  1.,  1.,  1.,  1.],
          [44., 45., 46., 47.,  1.,  1.,  1.,  1.]]]])
d.dim4
d.shapetorch.Size([2, 2, 3, 8])

你可能感兴趣的:(#,Pytorch数据集Tools,pytorch,深度学习,python,人工智能,程序人生)