torch.einsum一般这样用:
r e s u l t = t o r c h . e i n s u m ( ′ xx,xx,xx → xx ′ , a r g 1 , a r g 2 , a r g 3 ) result=torch.einsum\left ( '\textrm{xx,xx,xx}\rightarrow \textrm{xx}',arg1,arg2,arg3\right ) result=torch.einsum(′xx,xx,xx→xx′,arg1,arg2,arg3)
问题在于如何把里面的einsum表达式 ′ xx,xx,xx → xx ′ , a r g 1 , a r g 2 , a r g 3 '\textrm{xx,xx,xx}\rightarrow \textrm{xx}',arg1,arg2,arg3 ′xx,xx,xx→xx′,arg1,arg2,arg3,还原为直观的张量计算形式。
里面的 a r g arg arg为输入的张量, → \rightarrow →左边的 xx,xx,xx \textrm{xx,xx,xx} xx,xx,xx为输入张量的下标,如 ik,jkl,il \textrm{ik,jkl,il} ik,jkl,il,有多少个输入就有多少个下标, → \rightarrow →右边的 xx \textrm{xx} xx为输出张量的下标,有时候也空着。恢复步骤如下:
R xx = a r g 1 xx ⋅ a r g 2 xx ⋅ a r g 3 xx R_{\textrm{xx}}=arg1_{\textrm{xx}}\cdot arg2_{\textrm{xx}}\cdot arg3_{\textrm{xx}} Rxx=arg1xx⋅arg2xx⋅arg3xx
R xx = ∑ ∑ a r g 1 xx ⋅ a r g 2 xx ⋅ a r g 3 xx R_{\textrm{xx}}=\sum \sum arg1_{\textrm{xx}}\cdot arg2_{\textrm{xx}}\cdot arg3_{\textrm{xx}} Rxx=∑∑arg1xx⋅arg2xx⋅arg3xx
第一步:
R i j l = A i j k B i k l R_{ijl}=A_{ijk}B_{ikl} Rijl=AijkBikl
第二步:
k k k没有在输出张量下标中出现,所以:
R i j l = ∑ k A i j k B i k l R_{ijl}=\sum_{k}A_{ijk}B_{ikl} Rijl=∑kAijkBikl
第一步:
R i = A i j R_i=A_{ij} Ri=Aij
第二步:
j j j没有在输出张量下标中出现,所以:
R i = ∑ j A i j R_i=\sum_jA_{ij} Ri=∑jAij
第一步:
R = A i j R=A_{ij} R=Aij
第二步:
R = ∑ i ∑ j A i j R=\sum_{i}\sum_{j}A_{ij} R=∑i∑jAij