Einsum in PyTorch✍️


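Einsum is a domain-specific language (DSL) that expresses operations such as dot products, outer products, transposes, matrix-vector multiplication, and matrix-matrix multiplication compactly and with high performance. In frameworks such as NumPy, TensorFlow, and PyTorch it can make code more concise and can improve both programming efficiency and the speed of matrix operations.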


Einstein summation is implemented in NumPy (einsum in numpy), in TensorFlow (einsum in tensorflow), and in PyTorch (einsum in PyTorch). For blog posts explaining Einstein summation, see Olexa and Alex.

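Einstein summation is a good replacement for the separate dot-product, outer-product, transpose, and matrix-vector or matrix-matrix multiplication functions in TensorFlow or PyTorch: one concise notation covers all of these common operations. An einsum-like domain-specific language is also the basis of Tensor Comprehensions in PyTorch, which can automatically generate GPU code and auto-tune that code for specific input sizes. opt_einsum and tf_einsum_opt can be used to optimize the tensor contraction order.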
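Contraction order starts to matter once an expression has more than two operands. As a minimal sketch of the idea, NumPy's own `np.einsum_path` reports the order it would pick. The shapes below are assumptions chosen for illustration, and opt_einsum itself is not used here.

```python
import numpy as np

# Illustrative shapes (assumed, not taken from the article): a chain of three matrices.
a = np.random.rand(10, 1000)
b = np.random.rand(1000, 10)
c = np.random.rand(10, 1000)

# einsum_path returns the chosen contraction order and a printable cost report.
path, report = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')

# Passing the path back in reuses that order instead of searching again.
result = np.einsum('ij,jk,kl->il', a, b, c, optimize=path)

print(path)
print(result.shape)
```

Printing `report` shows the estimated cost of the naive and the optimized evaluation, which is a quick way to see whether reordering is worth it for a given expression.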


NumPy provides np.einsum, PyTorch provides torch.einsum, and TensorFlow provides tf.einsum. All three share the signature einsum(equation, operands), where equation is a string describing the Einstein summation and operands are the tensors it operates on. For example:
\[ {\color{green}c_j} = \sum_i\sum_k {\color{red}A_{ik}}{\color{blue}B_{kj}} \]
can be written as the equation string:
ik,kj->j

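Each character inside the quoted equation string is a placeholder for one dimension of a tensor. For example, in a call of the form einsum("□□,□□□,□□->□□", arg1, arg2, arg3), arg1 and arg3 are 2-D tensors (matrices), arg2 is a 3-D tensor, and the result is a 2-D matrix. einsum() accepts any number of operands, and each operand can have any number of dimensions.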
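To make the index bookkeeping concrete, the sketch below evaluates `ik,kj->j` with explicit loops and compares the result with `torch.einsum`. The tensor values are arbitrary and chosen only for illustration.

```python
import torch

A = torch.arange(6.).reshape(2, 3)    # indices i, k
B = torch.arange(12.).reshape(3, 4)   # indices k, j

# i and k do not appear after '->', so both are summed over.
c_loop = torch.zeros(4)
for i in range(A.shape[0]):
    for k in range(A.shape[1]):
        for j in range(B.shape[1]):
            c_loop[j] += A[i, k] * B[k, j]

c_einsum = torch.einsum('ik,kj->j', A, B)
print(torch.allclose(c_loop, c_einsum))  # True
```

The rule of thumb: an index that appears in the inputs but not in the output is summed away, and the order of the output indices fixes the layout of the result.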

1. Matrix Transpose

\[ {\color{green}B_{ji}} = {\color{red}A_{ij}} \]

import torch
a = torch.arange(6).reshape(2,3)
a_ = torch.einsum('ij->ji', [a])  # a itself is left unchanged; einsum returns a new tensor
>>> a = torch.arange(6).reshape(2,3)
>>> a
tensor([[0, 1, 2],
        [3, 4, 5]])
>>> torch.einsum("ij->ji", a)
tensor([[0, 3],
        [1, 4],
        [2, 5]])
>>> a
tensor([[0, 1, 2],
        [3, 4, 5]])

2. Sum

\[ {\color{green}b} = \sum_i\sum_j {\color{red}A_{ij}} = {\color{red}A_{ij}} \]

a = torch.arange(6).reshape(2,3)
a_ = torch.einsum("ij->", a)  # sum over all elements; the result is a 0-dim (scalar) tensor
>>> a_ = torch.einsum("ij->",a)
>>> a_
tensor(15)
>>> a_.size()
torch.Size([])

3. Column Sum

\[ {\color{green}b_j} = \sum_i{\color{red}A_{ij}} = {\color{red}A_{ij}} \]

>>> a_ = torch.einsum("ij->j", a)  # keep the column index j; sum over the row index i
>>> a_
tensor([3, 5, 7])

4. Row Sum

\[ {\color{green}b_i} = \sum_j{\color{red}A_{ij}} = {\color{red}A_{ij}} \]

>>> a_ = torch.einsum('ij -> i', a)
>>> a_
tensor([ 3, 12])
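As a quick cross-check of sections 2 to 4, the three reductions can be compared with `Tensor.sum`. This is a small sketch using the same 2x3 tensor as above.

```python
import torch

a = torch.arange(6).reshape(2, 3)

total = torch.einsum('ij->', a)      # sum over both indices
col_sum = torch.einsum('ij->j', a)   # sum over rows i, one value per column
row_sum = torch.einsum('ij->i', a)   # sum over columns j, one value per row

print(torch.equal(total, a.sum()))         # True
print(torch.equal(col_sum, a.sum(dim=0)))  # True
print(torch.equal(row_sum, a.sum(dim=1)))  # True
```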

5. Matrix-Vector Multiplication

\[ {\color{green}c_i} = \sum_k {\color{red}A_{ik}}{\color{blue}b_k} \]

>>> a
tensor([[0, 1, 2],
        [3, 4, 5]])
>>> b = torch.arange(3)
>>> b
tensor([0, 1, 2])
>>> a_ = torch.einsum('ik,k -> i', [a, b])
>>> a_
tensor([ 5, 14])

6. Matrix-Matrix Multiplication

\[ {\color{green}c_{ij}} = \sum_k {\color{red}A_{ik}}{\color{blue}B_{kj}} = {\color{red}A_{ik}}{\color{blue}B_{kj}} \]

>>> a
tensor([[0, 1, 2],
        [3, 4, 5]])
>>> b = torch.arange(15).reshape(3, 5)
>>> b
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])
>>> a_ = torch.einsum('ik,kj -> ij', [a, b])
>>> a_
tensor([[ 25,  28,  31,  34,  37],
        [ 70,  82,  94, 106, 118]])
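The same product can be checked against the `@` operator. The sketch below also shows that reordering the output indices gives the transposed product without a separate transpose call.

```python
import torch

a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(3, 5)

c_einsum = torch.einsum('ik,kj->ij', a, b)
c_matmul = a @ b

# Writing the output indices as 'ji' returns the transpose of the product.
c_t = torch.einsum('ik,kj->ji', a, b)

print(torch.equal(c_einsum, c_matmul))  # True
print(torch.equal(c_t, c_matmul.t()))   # True
```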

7. Dot Product

(Vector dot product)
\[ {\color{green}c} = \sum_i {\color{red}a_i}{\color{blue}b_i} = {\color{red}a_i}{\color{blue}b_i} \]

a = torch.arange(3)
b = torch.arange(3,6)  # -- a vector of length 3 containing [3, 4, 5]
a_ = torch.einsum('i,i->', [a, b])
>>> a_
tensor(14)

(Matrix dot product)
\[ {\color{green}c} = \sum_i\sum_j {\color{red}A_{ij}}{\color{blue}B_{ij}} = {\color{red}A_{ij}}{\color{blue}B_{ij}} \]

a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
a_ = torch.einsum('ij,ij->', [a, b])
>>> a_
tensor(145)
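Both dot products can be cross-checked against more familiar calls. This is a sketch using the same values as above: `torch.dot` for the vector case and an element-wise product followed by a full sum for the matrix case.

```python
import torch

a = torch.arange(3)
b = torch.arange(3, 6)
vec_dot = torch.einsum('i,i->', a, b)

A = torch.arange(6).reshape(2, 3)
B = torch.arange(6, 12).reshape(2, 3)
mat_dot = torch.einsum('ij,ij->', A, B)

print(vec_dot.item(), torch.dot(a, b).item())   # 14 14
print(mat_dot.item(), (A * B).sum().item())     # 145 145
```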

8. Hadamard Product

\[ {\color{green}C_{ij}} = {\color{red}A_{ij}}{\color{blue}B_{ij}} \]

>>> a = torch.arange(6).reshape(2,3)
>>> a
tensor([[0, 1, 2],
        [3, 4, 5]])
>>> b = torch.arange(6, 12).reshape(2,3)
>>> b
tensor([[ 6,  7,  8],
        [ 9, 10, 11]])
>>> a_ = torch.einsum('ij,ij->ij', [a,b])
>>> a_
tensor([[ 0,  7, 16],
        [27, 40, 55]])


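# Reductions over image-like tensors of shape [batch, channel, width, height] (bcwh)
a = torch.randn([4, 3, 224, 224])  # bcwh
b = torch.randn([4, 3, 224, 224])  # bcwh

# Sum over the spatial dimensions w and h, keeping b and c
a_ = torch.einsum('bcwh -> bc', [a])
# Equivalent explicit loop
a_ = torch.zeros(a.shape[0], a.shape[1])
for vi in range(a.shape[0]):
    for vj in range(a.shape[1]):
        a_[vi][vj] = torch.sum(a[vi, vj, ...])  # same as torch.sum(a[vi, vj, :, :])

# Element-wise product followed by a sum over w and h
intersection_ = torch.einsum('bcwh, bcwh -> bc', [a, b])
# Equivalent explicit loop
temp = torch.mul(a, b)
intersection_ = torch.zeros(a.shape[0], a.shape[1])
for vi in range(a.shape[0]):
    for vj in range(a.shape[1]):
        intersection_[vi, vj] = torch.sum(temp[vi, vj, :, :])

# Sum over the channel dimension, keeping the batch dimension
a_ = torch.einsum('bc -> b', [intersection_])
a_ = torch.sum(intersection_, 1)  # equivalent: dim=1 sums over the channels
>>> a_.size()
torch.Size([4])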
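The explicit loops above can also be replaced by a vectorized reference. The sketch below uses smaller tensors than the article (the 2x3x4x5 shape is an assumption made to keep the check fast) and compares the einsum against an element-wise product summed over the last two dimensions.

```python
import torch

torch.manual_seed(0)
a = torch.randn(2, 3, 4, 5)  # bcwh, deliberately small
b = torch.randn(2, 3, 4, 5)  # bcwh

# Multiply element-wise and sum over w and h in a single call.
per_channel = torch.einsum('bcwh,bcwh->bc', a, b)
reference = (a * b).sum(dim=(2, 3))

# Then sum over the channels to get one value per sample.
per_sample = torch.einsum('bc->b', per_channel)

print(torch.allclose(per_channel, reference, atol=1e-5))
print(per_sample.shape)
```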

9. Outer Product (linear algebra)

\[ {\color{green}C_{ij}} = {\color{red}a_i}{\color{blue}b_j} \]
Note: "outer product" here means the tensor product of two vectors. The exterior product from analytic geometry is a different operation.

>>> a = torch.arange(3)
>>> b = torch.arange(3,7)
>>> b
tensor([3, 4, 5, 6])
>>> a_ = torch.einsum('i, j -> ij', [a,b])
>>> a_
tensor([[ 0,  0,  0,  0],
        [ 3,  4,  5,  6],
        [ 6,  8, 10, 12]])
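For comparison, the same result can be obtained with `torch.outer` or with plain broadcasting. This is a sketch with the same vectors as above.

```python
import torch

a = torch.arange(3)
b = torch.arange(3, 7)

c_einsum = torch.einsum('i,j->ij', a, b)
c_outer = torch.outer(a, b)
c_broadcast = a[:, None] * b[None, :]  # shape [3, 1] times shape [1, 4]

print(torch.equal(c_einsum, c_outer))      # True
print(torch.equal(c_einsum, c_broadcast))  # True
```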

10. Batch Matrix Multiplication

\[ {\color{green}C_{ijl}} = \sum_k{\color{red}A_{ijk}}{\color{blue}B_{ikl}} = {\color{red}A_{ijk}}{\color{blue}B_{ikl}} \]
\[ A_{3\times{\color{red}2\times 5}}\,B_{3\times{\color{red}5\times 3}} = C_{3\times{\color{red}2\times 3}} \]

>>> a = torch.randn(3,2,5)
>>> a
tensor([[[ 1.0289, -0.1355, -0.4099, -1.4501, -0.5968],
         [-0.1667, -0.0286,  2.4426, -1.6169, -1.4544]],

        [[ 0.6575, -0.4081, -0.3920, -1.5732, -0.2900],
         [ 0.8656, -0.1987,  0.6598,  0.6242,  0.8018]],

        [[ 0.5719,  0.0117,  0.6081,  0.0559, -1.6444],
         [ 0.6912, -1.0747, -0.6326, -0.7700,  0.4683]]])
>>> b = torch.randn(3,5,3)
>>> b
tensor([[[-0.8491,  0.2271,  0.3570],
         [-2.1355, -0.2560,  1.1238],
         [ 2.1473, -0.7215, -0.1438],
         [ 1.0483,  0.0798,  0.6447],
         [-0.0853,  1.1740, -0.9754]],

        [[ 0.0629,  0.5650,  0.6519],
         [ 0.3167, -0.1258, -0.3802],
         [-1.8143, -1.5648,  0.8874],
         [-0.6165,  0.7251,  0.5943],
         [ 1.5743, -1.2373, -2.3647]],

        [[-1.9334, -3.4358,  0.7900],
         [ 1.5289,  0.7921, -0.7569],
         [ 1.5598,  0.1017, -0.2741],
         [-2.2164, -1.4216,  0.6508],
         [-0.3690, -0.0338, -1.3386]]])
>>> a_ = torch.einsum('ijk, ikl -> ijl', [a, b])  # ordinary matrix multiplication for each batch index i
>>> a_
tensor([[[-2.9337, -0.2524, -0.0787],
         [ 3.8770, -3.6294, -0.0668]],

        [[ 1.1368,  0.2542, -0.0134],
         [-0.3282, -1.0579, -0.2997]],

        [[ 0.3435, -1.9180,  2.5137],
         [-2.4323, -2.2114,  0.4050]]])
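Because the inputs above are random, the printed values cannot be reproduced exactly. A sketch of a reproducible check, assuming a fixed seed, compares the einsum with `torch.bmm` and with a per-batch loop.

```python
import torch

torch.manual_seed(0)
a = torch.randn(3, 2, 5)
b = torch.randn(3, 5, 3)

c_einsum = torch.einsum('ijk,ikl->ijl', a, b)
c_bmm = torch.bmm(a, b)

# One ordinary matrix product per batch index i, stacked back together.
c_loop = torch.stack([a[i] @ b[i] for i in range(a.shape[0])])

print(torch.allclose(c_einsum, c_bmm, atol=1e-5))
print(torch.allclose(c_einsum, c_loop, atol=1e-5))
print(c_einsum.shape)
```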

11. Tensor Contraction (multiplication of multi-dimensional arrays)

\[ {\color{green}C_{pstuv}} = \sum_q\sum_r{\color{red}A_{pqrs}}{\color{blue}B_{tuqvr}} = {\color{red}A_{pqrs}}{\color{blue}B_{tuqvr}} \]
Note: the batch matrix multiplication in section 10 is a special case of tensor contraction. Example: take an order-n tensor \({\color{red}A(I_1,\dots,I_n)}\) and an order-m tensor \({\color{blue}B(J_1,\dots,J_m)}\). Assume n=4 and m=5, with \(I_2=J_3\) and \(I_3=J_5\). A and B can then be multiplied by contracting dimensions 2 and 3 of A with dimensions 3 and 5 of B (a tensor contraction), which yields a tensor of the following shape: \({\color{green}C(I_1\times I_4\times J_1\times J_2\times J_4)}\)

>>> a = torch.randn(2,3,5,7)
>>> b = torch.randn(11,13,3,17,5)
>>> torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape
torch.Size([2, 7, 11, 13, 17])
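The same contraction can be written with `torch.tensordot` by listing which dimensions of each operand are contracted. The sketch below assumes a fixed seed and uses zero-based dimension numbers, so dimensions 2 and 3 of A become `[1, 2]` and dimensions 3 and 5 of B become `[2, 4]`.

```python
import torch

torch.manual_seed(0)
a = torch.randn(2, 3, 5, 7)
b = torch.randn(11, 13, 3, 17, 5)

c_einsum = torch.einsum('pqrs,tuqvr->pstuv', a, b)

# tensordot keeps the free dimensions of a first, then those of b,
# which matches the 'pstuv' output order used above.
c_tensordot = torch.tensordot(a, b, dims=([1, 2], [2, 4]))

print(c_tensordot.shape)
print(torch.allclose(c_einsum, c_tensordot, atol=1e-4))
```

einsum is the more flexible of the two when the output order needs to differ, since `tensordot` always returns the free dimensions in the fixed order described in the comment.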

12. Bilinear Transformation \[y = x_1^T A x_2\]

\[ {\color{green}D_{ij}} = \sum_k\sum_l{\color{red}A_{ik}}{\color{purple}B_{jkl}}{\color{blue}C_{il}} = {\color{red}A_{ik}}{\color{purple}B_{jkl}}{\color{blue}C_{il}} \]

Note: einsum can operate on more than two tensors; the bilinear transformation is one example.

torch.nn.Bilinear(in1_features, in2_features, out_features, bias=True)

in1_features – size of each first input sample

in2_features – size of each second input sample

out_features – size of each output sample

bias – If set to False, the layer will not learn an additive bias. Default: True

>>> a = torch.randn(2,3)
>>> a
tensor([[ 0.6125,  0.2316, -0.6014],
        [-0.8615, -0.9863,  1.2960]])
>>> b = torch.randn(5, 3, 7)
>>> b
tensor([[[-1.1676,  0.0660, -0.8437, -0.1520, -2.5223, -0.2943, -1.2686],
         [-0.6428,  1.5299, -1.0441,  1.1833,  0.3706,  0.6328,  0.6845],
         [-0.0108, -0.6111, -0.8319, -0.6406, -0.9033, -0.7832, -0.1204]],

        [[-1.1418,  0.6051,  0.2700,  0.4359,  0.5383,  1.5829,  0.0230],
         [ 1.4027, -0.3302,  1.2980,  0.4281, -1.2563, -1.8918, -0.7306],
         [ 0.2820,  0.0810, -0.4398, -0.3574, -0.8132, -0.1044,  0.1898]],

        [[-1.5765,  0.0495, -1.1407, -0.3947,  0.2474,  1.1779, -0.7999],
         [ 1.5090,  0.3881, -0.0344,  0.6265,  0.0975,  0.3283, -0.2570],
         [ 0.3752,  0.3056, -1.5868, -0.0883, -1.6226, -0.5940,  1.6271]],

        [[ 1.1354, -0.3212,  0.4247,  0.0681,  0.9360, -0.3384,  1.4519],
         [-0.1396, -0.8205, -0.1776,  0.4195,  0.2457, -0.9302,  1.4448],
         [-0.7819, -0.9069, -1.4125, -0.3920, -1.0380,  2.0351,  0.6595]],

        [[-1.2359,  0.3135, -0.1942,  0.2450, -1.0833, -0.1889,  0.2017],
         [ 0.1240, -1.3459, -1.1797,  0.2006,  1.5969,  0.4993,  1.2332],
         [ 0.1330,  0.2703,  0.5384,  2.3118, -0.3098, -0.0144,  0.7328]]])
>>> c = torch.randn(2,7)
>>> c
tensor([[-0.1731,  1.9290,  0.1928,  0.6679,  0.7227, -1.8295,  0.8021],
        [-1.0124,  1.5828,  1.6196, -0.1516,  0.9401,  0.2667, -1.2518]])
>>> a_ = torch.einsum('ik,jkl,il->ij', [a, b, c])
>>> a_.size()
torch.Size([2, 5])
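The three-operand einsum can be compared with the functional form of the bilinear layer, `torch.nn.functional.bilinear`, with no bias. In this sketch the seed is an assumption; the tensor `b` plays the role of the layer weight with shape [out_features, in1_features, in2_features].

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.randn(2, 3)      # first input:  [batch, in1_features]
b = torch.randn(5, 3, 7)   # weight:       [out_features, in1_features, in2_features]
c = torch.randn(2, 7)      # second input: [batch, in2_features]

d_einsum = torch.einsum('ik,jkl,il->ij', a, b, c)
d_bilinear = F.bilinear(a, c, b)  # argument order: input1, input2, weight

print(d_einsum.shape)
print(torch.allclose(d_einsum, d_bilinear, atol=1e-4))
```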

13. TreeQN

Given a low-dimensional state representation \(z_l\) at layer \(l\) and a transition function \(W^a\) for each action \(a\), a residual connection is used to compute the state representation of the next layer \(l+1\), written \(z^a_{l+1}\):
\[ \mathbf{z}^a_{l+1} = \mathbf{z}_l + \tanh(\mathbf{W}^a\mathbf{z}_l) \]
In practice we want to do this efficiently for a batch of \(B\) state representations of dimension \(K\), \(Z\in\mathbb{R}^{B\times K}\), and for all transition functions (that is, for all \(A\) actions) at the same time. The transition functions can be stacked in a tensor \(W\in\mathbb{R}^{A\times K\times K}\), and einsum then computes the next state representations efficiently.

import torch

def random_tensors(shape, num=1, requires_grad=False):
  # Return one random tensor, or a list of `num` random tensors, of the given shape
  tensors = [torch.randn(shape, requires_grad=requires_grad) for i in range(0, num)]
  return tensors[0] if num == 1 else tensors

# Parameters
# -- [num_actions x hidden_dimension]; b and W are learned parameters, so both require gradients
b = random_tensors([5, 3], requires_grad=True)
# -- [num_actions x hidden_dimension x hidden_dimension]
W = random_tensors([5, 3, 3], requires_grad=True)

def transition(zl):
  # zl.unsqueeze(1) reshapes zl to [2, 1, 3] so that it broadcasts over the actions
  # einsum returns a tensor of shape [2, 5, 3]; b has shape [5, 3], so '+' broadcasts
  # returns shape [2, 5, 3] -- [batch_size x num_actions x hidden_dimension]
  return zl.unsqueeze(1) + torch.tanh(torch.einsum("bk,aki->bai", [zl, W]) + b)

# Sampled dummy inputs
# -- [batch_size x hidden_dimension]
zl = random_tensors([2, 3])

zl_ = transition(zl)
>>> zl_.shape
torch.Size([2, 5, 3])
>>> zl_
tensor([[[-2.4080, -3.5068,  0.5025],
         [-2.1035, -1.5754,  2.4678],
         [-2.9639, -3.5726,  2.4945],
         [-2.8011, -1.9178,  0.4953],
         [-0.9763, -3.5736,  0.5021]],

        [[ 2.0257, -2.7589,  0.1339],
         [ 0.9365, -0.7853, -1.8133],
         [ 0.6122, -2.7831,  0.1818],
         [ 2.3801, -2.4952, -1.8072],
         [ 1.7467, -2.7832,  0.1259]]], grad_fn=)
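To see what the single einsum call replaces, the sketch below recomputes the transition one action at a time and compares the results. The seed and the names `num_actions`, `hidden`, and `batch` are assumptions for illustration; the bias `b` follows the article's code rather than the formula above, which has no bias term.

```python
import torch

torch.manual_seed(0)
num_actions, hidden, batch = 5, 3, 2
W = torch.randn(num_actions, hidden, hidden)
b = torch.randn(num_actions, hidden)
zl = torch.randn(batch, hidden)

# All actions at once: [batch, num_actions, hidden]
vectorized = zl.unsqueeze(1) + torch.tanh(torch.einsum('bk,aki->bai', zl, W) + b)

# One action at a time: zl @ W[a] sums over k, exactly as 'bk,aki->bai' does.
per_action = torch.stack(
    [zl + torch.tanh(zl @ W[a] + b[a]) for a in range(num_actions)],
    dim=1,
)

print(vectorized.shape)
print(torch.allclose(vectorized, per_action, atol=1e-5))
```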

