举例:
vector = torch.tensor([[1,2,3,4,5],[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1],[1,1,1,1,1]])
# One scale factor per output channel 'o': row 0 of `vector`, i.e. [1, 2, 3, 4, 5].
a = vector[0, :]
# weight has shape (o, i, h, w) = (5, 5, 3, 3), filled with 0..224.
weight = torch.arange(225).view(5, 5, 3, 3)
# 'oihw,o->oihw': every input index also appears in the output, so nothing is
# summed. This is an element-wise product broadcast along the 'o' axis:
#     out[o, i, h, w] = weight[o, i, h, w] * a[o]
# which equals weight * a.view(-1, 1, 1, 1). It is NOT a matrix multiplication;
# each output-channel block weight[o] is simply scaled by a[o].
out = torch.einsum('oihw,o->oihw', weight, a)
print(out)
weight:
tensor([[[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[ 12, 13, 14],
[ 15, 16, 17]],
[[ 18, 19, 20],
[ 21, 22, 23],
[ 24, 25, 26]],
[[ 27, 28, 29],
[ 30, 31, 32],
[ 33, 34, 35]],
[[ 36, 37, 38],
[ 39, 40, 41],
[ 42, 43, 44]]],
[[[ 45, 46, 47],
[ 48, 49, 50],
[ 51, 52, 53]],
[[ 54, 55, 56],
[ 57, 58, 59],
[ 60, 61, 62]],
[[ 63, 64, 65],
[ 66, 67, 68],
[ 69, 70, 71]],
[[ 72, 73, 74],
[ 75, 76, 77],
[ 78, 79, 80]],
[[ 81, 82, 83],
[ 84, 85, 86],
[ 87, 88, 89]]],
[[[ 90, 91, 92],
[ 93, 94, 95],
[ 96, 97, 98]],
[[ 99, 100, 101],
[102, 103, 104],
[105, 106, 107]],
[[108, 109, 110],
[111, 112, 113],
[114, 115, 116]],
[[117, 118, 119],
[120, 121, 122],
[123, 124, 125]],
[[126, 127, 128],
[129, 130, 131],
[132, 133, 134]]],
[[[135, 136, 137],
[138, 139, 140],
[141, 142, 143]],
[[144, 145, 146],
[147, 148, 149],
[150, 151, 152]],
[[153, 154, 155],
[156, 157, 158],
[159, 160, 161]],
[[162, 163, 164],
[165, 166, 167],
[168, 169, 170]],
[[171, 172, 173],
[174, 175, 176],
[177, 178, 179]]],
[[[180, 181, 182],
[183, 184, 185],
[186, 187, 188]],
[[189, 190, 191],
[192, 193, 194],
[195, 196, 197]],
[[198, 199, 200],
[201, 202, 203],
[204, 205, 206]],
[[207, 208, 209],
[210, 211, 212],
[213, 214, 215]],
[[216, 217, 218],
[219, 220, 221],
[222, 223, 224]]]])
out:
tensor([[[[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8]],
[[ 9, 10, 11],
[ 12, 13, 14],
[ 15, 16, 17]],
[[ 18, 19, 20],
[ 21, 22, 23],
[ 24, 25, 26]],
[[ 27, 28, 29],
[ 30, 31, 32],
[ 33, 34, 35]],
[[ 36, 37, 38],
[ 39, 40, 41],
[ 42, 43, 44]]],
[[[ 90, 92, 94],
[ 96, 98, 100],
[ 102, 104, 106]],
[[ 108, 110, 112],
[ 114, 116, 118],
[ 120, 122, 124]],
[[ 126, 128, 130],
[ 132, 134, 136],
[ 138, 140, 142]],
[[ 144, 146, 148],
[ 150, 152, 154],
[ 156, 158, 160]],
[[ 162, 164, 166],
[ 168, 170, 172],
[ 174, 176, 178]]],
[[[ 270, 273, 276],
[ 279, 282, 285],
[ 288, 291, 294]],
[[ 297, 300, 303],
[ 306, 309, 312],
[ 315, 318, 321]],
[[ 324, 327, 330],
[ 333, 336, 339],
[ 342, 345, 348]],
[[ 351, 354, 357],
[ 360, 363, 366],
[ 369, 372, 375]],
[[ 378, 381, 384],
[ 387, 390, 393],
[ 396, 399, 402]]],
[[[ 540, 544, 548],
[ 552, 556, 560],
[ 564, 568, 572]],
[[ 576, 580, 584],
[ 588, 592, 596],
[ 600, 604, 608]],
[[ 612, 616, 620],
[ 624, 628, 632],
[ 636, 640, 644]],
[[ 648, 652, 656],
[ 660, 664, 668],
[ 672, 676, 680]],
[[ 684, 688, 692],
[ 696, 700, 704],
[ 708, 712, 716]]],
[[[ 900, 905, 910],
[ 915, 920, 925],
[ 930, 935, 940]],
[[ 945, 950, 955],
[ 960, 965, 970],
[ 975, 980, 985]],
[[ 990, 995, 1000],
[1005, 1010, 1015],
[1020, 1025, 1030]],
[[1035, 1040, 1045],
[1050, 1055, 1060],
[1065, 1070, 1075]],
[[1080, 1085, 1090],
[1095, 1100, 1105],
[1110, 1115, 1120]]]])
# oi,hw->oihw: outer product -- for every (o, i), out[o, i] = weight1[o, i] * b.
weight = torch.arange(25).view(5, 5, 1, 1)
weight0 = weight[..., 0]   # drop the last singleton axis  -> shape (5, 5, 1)
weight1 = weight0[..., 0]  # drop the next singleton axis  -> shape (5, 5)
b = torch.tensor([[1, 2, 3], [1, 2, 2], [1, 1, 1]])
out = torch.einsum('oi,hw->oihw', weight1, b)
print(out)
weight1:
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]])
b:
tensor([[1, 2, 3],
[1, 2, 2],
[1, 1, 1]])
out:
tensor([[[[ 0, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0]],
[[ 1, 2, 3],
[ 1, 2, 2],
[ 1, 1, 1]],
[[ 2, 4, 6],
[ 2, 4, 4],
[ 2, 2, 2]],
[[ 3, 6, 9],
[ 3, 6, 6],
[ 3, 3, 3]],
[[ 4, 8, 12],
[ 4, 8, 8],
[ 4, 4, 4]]],
[[[ 5, 10, 15],
[ 5, 10, 10],
[ 5, 5, 5]],
[[ 6, 12, 18],
[ 6, 12, 12],
[ 6, 6, 6]],
[[ 7, 14, 21],
[ 7, 14, 14],
[ 7, 7, 7]],
[[ 8, 16, 24],
[ 8, 16, 16],
[ 8, 8, 8]],
[[ 9, 18, 27],
[ 9, 18, 18],
[ 9, 9, 9]]],
[[[10, 20, 30],
[10, 20, 20],
[10, 10, 10]],
[[11, 22, 33],
[11, 22, 22],
[11, 11, 11]],
[[12, 24, 36],
[12, 24, 24],
[12, 12, 12]],
[[13, 26, 39],
[13, 26, 26],
[13, 13, 13]],
[[14, 28, 42],
[14, 28, 28],
[14, 14, 14]]],
[[[15, 30, 45],
[15, 30, 30],
[15, 15, 15]],
[[16, 32, 48],
[16, 32, 32],
[16, 16, 16]],
[[17, 34, 51],
[17, 34, 34],
[17, 17, 17]],
[[18, 36, 54],
[18, 36, 36],
[18, 18, 18]],
[[19, 38, 57],
[19, 38, 38],
[19, 19, 19]]],
[[[20, 40, 60],
[20, 40, 40],
[20, 20, 20]],
[[21, 42, 63],
[21, 42, 42],
[21, 21, 21]],
[[22, 44, 66],
[22, 44, 44],
[22, 22, 22]],
[[23, 46, 69],
[23, 46, 46],
[23, 23, 23]],
[[24, 48, 72],
[24, 48, 48],
[24, 24, 24]]]])
# oi,ohw->oihw: 'o' is shared, nothing is summed, so
#     out[o, i, h, w] = weight1[o, i] * c[o, h, w]
# i.e. each row o of weight1 scales only its own 3x3 slice c[o].
weight = torch.arange(25).view(5, 5, 1, 1)
weight0 = weight[..., 0]   # (5, 5, 1, 1) -> (5, 5, 1)
weight1 = weight0[..., 0]  # (5, 5, 1)    -> (5, 5)
c = torch.ones(5, 3, 3)
# Slice 0 becomes [[1,2,3]] * 3 via broadcasting over its rows; slices 1..4 stay all ones.
c[0] = torch.tensor([1, 2, 3])
# weight1 is integer and c is float, so the result is promoted to float.
out = torch.einsum('oi,ohw->oihw', weight1, c)
print(out)
weight1:
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]])
c:
tensor([[[1., 2., 3.],
[1., 2., 3.],
[1., 2., 3.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]])
out:
tensor([[[[ 0., 0., 0.],
[ 0., 0., 0.],
[ 0., 0., 0.]],
[[ 1., 2., 3.],
[ 1., 2., 3.],
[ 1., 2., 3.]],
[[ 2., 4., 6.],
[ 2., 4., 6.],
[ 2., 4., 6.]],
[[ 3., 6., 9.],
[ 3., 6., 9.],
[ 3., 6., 9.]],
[[ 4., 8., 12.],
[ 4., 8., 12.],
[ 4., 8., 12.]]],
[[[ 5., 5., 5.],
[ 5., 5., 5.],
[ 5., 5., 5.]],
[[ 6., 6., 6.],
[ 6., 6., 6.],
[ 6., 6., 6.]],
[[ 7., 7., 7.],
[ 7., 7., 7.],
[ 7., 7., 7.]],
[[ 8., 8., 8.],
[ 8., 8., 8.],
[ 8., 8., 8.]],
[[ 9., 9., 9.],
[ 9., 9., 9.],
[ 9., 9., 9.]]],
[[[10., 10., 10.],
[10., 10., 10.],
[10., 10., 10.]],
[[11., 11., 11.],
[11., 11., 11.],
[11., 11., 11.]],
[[12., 12., 12.],
[12., 12., 12.],
[12., 12., 12.]],
[[13., 13., 13.],
[13., 13., 13.],
[13., 13., 13.]],
[[14., 14., 14.],
[14., 14., 14.],
[14., 14., 14.]]],
[[[15., 15., 15.],
[15., 15., 15.],
[15., 15., 15.]],
[[16., 16., 16.],
[16., 16., 16.],
[16., 16., 16.]],
[[17., 17., 17.],
[17., 17., 17.],
[17., 17., 17.]],
[[18., 18., 18.],
[18., 18., 18.],
[18., 18., 18.]],
[[19., 19., 19.],
[19., 19., 19.],
[19., 19., 19.]]],
[[[20., 20., 20.],
[20., 20., 20.],
[20., 20., 20.]],
[[21., 21., 21.],
[21., 21., 21.],
[21., 21., 21.]],
[[22., 22., 22.],
[22., 22., 22.],
[22., 22., 22.]],
[[23., 23., 23.],
[23., 23., 23.],
[23., 23., 23.]],
[[24., 24., 24.],
[24., 24., 24.],
[24., 24., 24.]]]])