torch.einsum详解

torch.einsum详解_第1张图片

 举例:

import torch

# Example 1: scale a 4-D weight tensor along its first axis (`o`).
# `vector` has 6 rows of 5 values; only its first row is used below.
vector = torch.tensor([
    [1, 2, 3, 4, 5],
    [1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1],
])
a = vector[0, :]  # shape (5,): one scale factor per index of the `o` axis
weight = torch.arange(225).view(5, 5, 3, 3)
# 'oihw,o->oihw': every index of the inputs also appears in the output, so
# nothing is summed; each element is just multiplied by the matching factor:
#     out[o, i, h, w] = weight[o, i, h, w] * a[o]
out = torch.einsum('oihw,o->oihw', weight, a)
print(out)
# This is an element-wise (broadcast) scaling along the `o` axis, i.e. the
# same as `weight * a.view(-1, 1, 1, 1)`; it is not a matrix multiplication.

weight:
tensor([[[[  0,   1,   2],
          [  3,   4,   5],
          [  6,   7,   8]],

         [[  9,  10,  11],
          [ 12,  13,  14],
          [ 15,  16,  17]],

         [[ 18,  19,  20],
          [ 21,  22,  23],
          [ 24,  25,  26]],

         [[ 27,  28,  29],
          [ 30,  31,  32],
          [ 33,  34,  35]],

         [[ 36,  37,  38],
          [ 39,  40,  41],
          [ 42,  43,  44]]],


        [[[ 45,  46,  47],
          [ 48,  49,  50],
          [ 51,  52,  53]],

         [[ 54,  55,  56],
          [ 57,  58,  59],
          [ 60,  61,  62]],

         [[ 63,  64,  65],
          [ 66,  67,  68],
          [ 69,  70,  71]],

         [[ 72,  73,  74],
          [ 75,  76,  77],
          [ 78,  79,  80]],

         [[ 81,  82,  83],
          [ 84,  85,  86],
          [ 87,  88,  89]]],


        [[[ 90,  91,  92],
          [ 93,  94,  95],
          [ 96,  97,  98]],

         [[ 99, 100, 101],
          [102, 103, 104],
          [105, 106, 107]],

         [[108, 109, 110],
          [111, 112, 113],
          [114, 115, 116]],

         [[117, 118, 119],
          [120, 121, 122],
          [123, 124, 125]],

         [[126, 127, 128],
          [129, 130, 131],
          [132, 133, 134]]],


        [[[135, 136, 137],
          [138, 139, 140],
          [141, 142, 143]],

         [[144, 145, 146],
          [147, 148, 149],
          [150, 151, 152]],

         [[153, 154, 155],
          [156, 157, 158],
          [159, 160, 161]],

         [[162, 163, 164],
          [165, 166, 167],
          [168, 169, 170]],

         [[171, 172, 173],
          [174, 175, 176],
          [177, 178, 179]]],


        [[[180, 181, 182],
          [183, 184, 185],
          [186, 187, 188]],

         [[189, 190, 191],
          [192, 193, 194],
          [195, 196, 197]],

         [[198, 199, 200],
          [201, 202, 203],
          [204, 205, 206]],

         [[207, 208, 209],
          [210, 211, 212],
          [213, 214, 215]],

         [[216, 217, 218],
          [219, 220, 221],
          [222, 223, 224]]]])
out:
tensor([[[[   0,    1,    2],
          [   3,    4,    5],
          [   6,    7,    8]],

         [[   9,   10,   11],
          [  12,   13,   14],
          [  15,   16,   17]],

         [[  18,   19,   20],
          [  21,   22,   23],
          [  24,   25,   26]],

         [[  27,   28,   29],
          [  30,   31,   32],
          [  33,   34,   35]],

         [[  36,   37,   38],
          [  39,   40,   41],
          [  42,   43,   44]]],


        [[[  90,   92,   94],
          [  96,   98,  100],
          [ 102,  104,  106]],

         [[ 108,  110,  112],
          [ 114,  116,  118],
          [ 120,  122,  124]],

         [[ 126,  128,  130],
          [ 132,  134,  136],
          [ 138,  140,  142]],

         [[ 144,  146,  148],
          [ 150,  152,  154],
          [ 156,  158,  160]],

         [[ 162,  164,  166],
          [ 168,  170,  172],
          [ 174,  176,  178]]],


        [[[ 270,  273,  276],
          [ 279,  282,  285],
          [ 288,  291,  294]],

         [[ 297,  300,  303],
          [ 306,  309,  312],
          [ 315,  318,  321]],

         [[ 324,  327,  330],
          [ 333,  336,  339],
          [ 342,  345,  348]],

         [[ 351,  354,  357],
          [ 360,  363,  366],
          [ 369,  372,  375]],

         [[ 378,  381,  384],
          [ 387,  390,  393],
          [ 396,  399,  402]]],


        [[[ 540,  544,  548],
          [ 552,  556,  560],
          [ 564,  568,  572]],

         [[ 576,  580,  584],
          [ 588,  592,  596],
          [ 600,  604,  608]],

         [[ 612,  616,  620],
          [ 624,  628,  632],
          [ 636,  640,  644]],

         [[ 648,  652,  656],
          [ 660,  664,  668],
          [ 672,  676,  680]],

         [[ 684,  688,  692],
          [ 696,  700,  704],
          [ 708,  712,  716]]],


        [[[ 900,  905,  910],
          [ 915,  920,  925],
          [ 930,  935,  940]],

         [[ 945,  950,  955],
          [ 960,  965,  970],
          [ 975,  980,  985]],

         [[ 990,  995, 1000],
          [1005, 1010, 1015],
          [1020, 1025, 1030]],

         [[1035, 1040, 1045],
          [1050, 1055, 1060],
          [1065, 1070, 1075]],

         [[1080, 1085, 1090],
          [1095, 1100, 1105],
          [1110, 1115, 1120]]]])
# Example 2, pattern 'oi,hw->oihw': the two operands share no index and all
# four indices are kept, so the result is their outer product:
#     out[o, i, h, w] = weight1[o, i] * b[h, w]
b = torch.tensor([
    [1, 2, 3],
    [1, 2, 2],
    [1, 1, 1],
])
weight = torch.arange(25).reshape(5, 5, 1, 1)
# Drop the two trailing size-1 axes one at a time: (5, 5, 1, 1) -> (5, 5).
weight0 = weight.squeeze(-1)
weight1 = weight0.squeeze(-1)
out = torch.einsum('oi,hw->oihw', weight1, b)
print(out)

weight1:
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24]])
b:
tensor([[1, 2, 3],
        [1, 2, 2],
        [1, 1, 1]])
out:
tensor([[[[ 0,  0,  0],
          [ 0,  0,  0],
          [ 0,  0,  0]],

         [[ 1,  2,  3],
          [ 1,  2,  2],
          [ 1,  1,  1]],

         [[ 2,  4,  6],
          [ 2,  4,  4],
          [ 2,  2,  2]],

         [[ 3,  6,  9],
          [ 3,  6,  6],
          [ 3,  3,  3]],

         [[ 4,  8, 12],
          [ 4,  8,  8],
          [ 4,  4,  4]]],


        [[[ 5, 10, 15],
          [ 5, 10, 10],
          [ 5,  5,  5]],

         [[ 6, 12, 18],
          [ 6, 12, 12],
          [ 6,  6,  6]],

         [[ 7, 14, 21],
          [ 7, 14, 14],
          [ 7,  7,  7]],

         [[ 8, 16, 24],
          [ 8, 16, 16],
          [ 8,  8,  8]],

         [[ 9, 18, 27],
          [ 9, 18, 18],
          [ 9,  9,  9]]],


        [[[10, 20, 30],
          [10, 20, 20],
          [10, 10, 10]],

         [[11, 22, 33],
          [11, 22, 22],
          [11, 11, 11]],

         [[12, 24, 36],
          [12, 24, 24],
          [12, 12, 12]],

         [[13, 26, 39],
          [13, 26, 26],
          [13, 13, 13]],

         [[14, 28, 42],
          [14, 28, 28],
          [14, 14, 14]]],


        [[[15, 30, 45],
          [15, 30, 30],
          [15, 15, 15]],

         [[16, 32, 48],
          [16, 32, 32],
          [16, 16, 16]],

         [[17, 34, 51],
          [17, 34, 34],
          [17, 17, 17]],

         [[18, 36, 54],
          [18, 36, 36],
          [18, 18, 18]],

         [[19, 38, 57],
          [19, 38, 38],
          [19, 19, 19]]],


        [[[20, 40, 60],
          [20, 40, 40],
          [20, 20, 20]],

         [[21, 42, 63],
          [21, 42, 42],
          [21, 21, 21]],

         [[22, 44, 66],
          [22, 44, 44],
          [22, 22, 22]],

         [[23, 46, 69],
          [23, 46, 46],
          [23, 23, 23]],

         [[24, 48, 72],
          [24, 48, 48],
          [24, 24, 24]]]])
# Example 3, pattern 'oi,ohw->oihw': index `o` appears in both operands and
# in the output, so it is matched rather than summed:
#     out[o, i, h, w] = weight1[o, i] * c[o, h, w]
weight = torch.arange(25).view(5, 5, 1, 1)
weight0 = weight.squeeze(3)
weight1 = weight0.squeeze(2)  # shape (5, 5); integer dtype from arange
c = torch.ones(5, 3, 3)  # floating-point dtype (float32 by default)
# The row [1, 2, 3] is broadcast over the first (3, 3) slice; every other
# slice keeps its all-ones values.
c[0, :, :] = torch.tensor([1, 2, 3])
# weight1 is an integer tensor while c is floating point. Cast explicitly so
# both einsum operands share a dtype instead of depending on implicit
# promotion inside einsum; the values are unchanged.
out = torch.einsum('oi,ohw->oihw', weight1.to(c.dtype), c)

print(out)

weight1:
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24]])

c:
tensor([[[1., 2., 3.],
         [1., 2., 3.],
         [1., 2., 3.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])
out:
tensor([[[[ 0.,  0.,  0.],
          [ 0.,  0.,  0.],
          [ 0.,  0.,  0.]],

         [[ 1.,  2.,  3.],
          [ 1.,  2.,  3.],
          [ 1.,  2.,  3.]],

         [[ 2.,  4.,  6.],
          [ 2.,  4.,  6.],
          [ 2.,  4.,  6.]],

         [[ 3.,  6.,  9.],
          [ 3.,  6.,  9.],
          [ 3.,  6.,  9.]],

         [[ 4.,  8., 12.],
          [ 4.,  8., 12.],
          [ 4.,  8., 12.]]],


        [[[ 5.,  5.,  5.],
          [ 5.,  5.,  5.],
          [ 5.,  5.,  5.]],

         [[ 6.,  6.,  6.],
          [ 6.,  6.,  6.],
          [ 6.,  6.,  6.]],

         [[ 7.,  7.,  7.],
          [ 7.,  7.,  7.],
          [ 7.,  7.,  7.]],

         [[ 8.,  8.,  8.],
          [ 8.,  8.,  8.],
          [ 8.,  8.,  8.]],

         [[ 9.,  9.,  9.],
          [ 9.,  9.,  9.],
          [ 9.,  9.,  9.]]],


        [[[10., 10., 10.],
          [10., 10., 10.],
          [10., 10., 10.]],

         [[11., 11., 11.],
          [11., 11., 11.],
          [11., 11., 11.]],

         [[12., 12., 12.],
          [12., 12., 12.],
          [12., 12., 12.]],

         [[13., 13., 13.],
          [13., 13., 13.],
          [13., 13., 13.]],

         [[14., 14., 14.],
          [14., 14., 14.],
          [14., 14., 14.]]],


        [[[15., 15., 15.],
          [15., 15., 15.],
          [15., 15., 15.]],

         [[16., 16., 16.],
          [16., 16., 16.],
          [16., 16., 16.]],

         [[17., 17., 17.],
          [17., 17., 17.],
          [17., 17., 17.]],

         [[18., 18., 18.],
          [18., 18., 18.],
          [18., 18., 18.]],

         [[19., 19., 19.],
          [19., 19., 19.],
          [19., 19., 19.]]],


        [[[20., 20., 20.],
          [20., 20., 20.],
          [20., 20., 20.]],

         [[21., 21., 21.],
          [21., 21., 21.],
          [21., 21., 21.]],

         [[22., 22., 22.],
          [22., 22., 22.],
          [22., 22., 22.]],

         [[23., 23., 23.],
          [23., 23., 23.],
          [23., 23., 23.]],

         [[24., 24., 24.],
          [24., 24., 24.],
          [24., 24., 24.]]]])

你可能感兴趣的:(python,pytorch,深度学习)