PyTorch——通俗理解torch.einsum

参考链接

  1. https://www.cnblogs.com/mengnan/p/10319701.html

从einsum表达式恢复为张量计算

torch.einsum一般这样用:

r e s u l t = t o r c h . e i n s u m ( ′ xx,xx,xx → xx ′ , a r g 1 , a r g 2 , a r g 3 ) result=torch.einsum\left ( '\textrm{xx,xx,xx}\rightarrow \textrm{xx}',arg1,arg2,arg3\right ) result=torch.einsum(xx,xx,xxxx,arg1,arg2,arg3)

问题在于如何把里面的einsum表达式 ′ xx,xx,xx → xx ′ , a r g 1 , a r g 2 , a r g 3 '\textrm{xx,xx,xx}\rightarrow \textrm{xx}',arg1,arg2,arg3 xx,xx,xxxx,arg1,arg2,arg3,还原为直观的张量计算形式。

里面的 a r g arg arg为输入的张量, → \rightarrow 左边的 xx,xx,xx \textrm{xx,xx,xx} xx,xx,xx为输入张量的下标,如 ik,jkl,il \textrm{ik,jkl,il} ik,jkl,il,有多少个输入就有多少个下标, → \rightarrow 右边的 xx \textrm{xx} xx为输出张量的下标,有时候也空着。恢复步骤如下:

  1. 首先写成输出的模板,用 R R R表示输出张量。

R xx = a r g 1 xx ⋅ a r g 2 xx ⋅ a r g 3 xx R_{\textrm{xx}}=arg1_{\textrm{xx}}\cdot arg2_{\textrm{xx}}\cdot arg3_{\textrm{xx}} Rxx=arg1xxarg2xxarg3xx

  1. 将输入张量下标(即 → \rightarrow 左边的 xx,xx,xx \textrm{xx,xx,xx} xx,xx,xx)中没有在输出张量下标(即 → \rightarrow 右边的 xx \textrm{xx} xx)出现的字母,用求和符号加在等式右边。

R xx = ∑ ∑ a r g 1 xx ⋅ a r g 2 xx ⋅ a r g 3 xx R_{\textrm{xx}}=\sum \sum arg1_{\textrm{xx}}\cdot arg2_{\textrm{xx}}\cdot arg3_{\textrm{xx}} Rxx=arg1xxarg2xxarg3xx

实例

1. torch.einsum(‘ijk,ikl->ijl’,[a,b])

第一步:

R i j l = A i j k B i k l R_{ijl}=A_{ijk}B_{ikl} Rijl=AijkBikl

第二步:
k k k没有在输出张量下标中出现,所以:

R i j l = ∑ k A i j k B i k l R_{ijl}=\sum_{k}A_{ijk}B_{ikl} Rijl=kAijkBikl

2. torch.einsum(‘ij->i’, [a])

第一步:

R i = A i j R_i=A_{ij} Ri=Aij

第二步:
j j j没有在输出张量下标中出现,所以:

R i = ∑ j A i j R_i=\sum_jA_{ij} Ri=jAij

3. torch.einsum(‘ij->’,[a])

第一步:

R = A i j R=A_{ij} R=Aij

第二步:

R = ∑ i ∑ j A i j R=\sum_{i}\sum_{j}A_{ij} R=ijAij

你可能感兴趣的:(PyTorch,pytorch,einsum,字符串)