Einsum is a domain-specific language that can express operations such as dot products, outer products, transposes, matrix-vector multiplication, and matrix-matrix multiplication with high performance. In frameworks such as numpy, tensorflow, and pytorch it greatly improves both coding efficiency and the speed of matrix operations, while keeping the code concise.
Einstein summation is implemented as einsum in numpy, einsum in tensorflow, and einsum in PyTorch. A related blog post explaining Einstein summation: OlexaAlex.
Einstein summation is a good replacement for the dot product, outer product, transpose, and matrix-vector or matrix-matrix multiplication functions in TensorFlow or PyTorch: it expresses all of these common operations concisely and makes for efficient code. einsum is a domain-specific language that enables high-performance implementations; this kind of syntax is the basis of Tensor Comprehensions in PyTorch, which automatically generate GPU code and tune it for inputs of specific sizes. opt_einsum and tf_einsum_opt can be used to optimize the contraction order of the tensors.
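A minimal sketch of contraction-order optimization (my illustration, not from the original post; it uses NumPy's built-in optimize flag, which plays the same role as opt_einsum):

import numpy as np

A = np.random.rand(10, 100)
B = np.random.rand(100, 1000)
C = np.random.rand(1000, 5)

# With optimize=True, numpy is free to choose a cheaper contraction order,
# e.g. (A @ B) @ C versus A @ (B @ C), instead of contracting strictly left to right.
D = np.einsum('ij,jk,kl->il', A, B, C, optimize=True)
print(D.shape)  # (10, 5)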
Implementation in NumPy: np.einsum; in PyTorch: torch.einsum; in TensorFlow: tf.einsum. All three functions share the signature einsum(equation, operands), where equation is the Einstein-summation string and operands are the tensors to operate on. For example:
$$ {\color{green}c_j} = \sum_i\sum_k {\color{red}A_{ik}}{\color{blue}B_{kj}} $$
can be written as the equation string:
$$ {\color{red}ik},{\color{blue}kj} \rightarrow {\color{green}j} $$
Note: the index letters i, j, k can be any letters, but they must be used consistently throughout the equation string.
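As a quick illustration (the tensors A and B here are just placeholders), this equation string is passed directly to einsum:

import torch

A = torch.randn(2, 3)
B = torch.randn(3, 4)
# 'ik,kj->j': multiply A and B and sum over both i and k,
# leaving only the j dimension.
c = torch.einsum('ik,kj->j', A, B)
print(c.shape)  # torch.Size([4])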
Einstein summation can also be used to express arbitrary computation graphs in neural networks and to backpropagate through them. It can be summarized in the following general form:
result = einsum("□□,□□□,□□->□□", arg1, arg2, arg3)
where the squares inside the quotes are placeholders for the individual dimensions of the tensors. From this string we can see that arg1 and arg3 are 2-D tensors (matrices), arg2 is a 3-D tensor, and result is a 2-D matrix. einsum() can take any number of inputs of different dimensionalities.
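A small sketch of backpropagating through an einsum expression (my example, not from the original text):

import torch

A = torch.randn(2, 3, requires_grad=True)
B = torch.randn(3, 4, requires_grad=True)
# einsum is an ordinary autograd operation, so gradients flow back to A and B.
loss = torch.einsum('ik,kj->ij', A, B).sum()
loss.backward()
print(A.grad.shape, B.grad.shape)  # torch.Size([2, 3]) torch.Size([3, 4])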
$$ {\color{green}B_{ji}} = {\color{red}A_{ij}} $$
import torch
a = torch.arange(6).reshape(2,3)
a_ = torch.einsum('ij->ji', [a]) # this does not modify a; the transposed result is returned as a new tensor
>>> a = torch.arange(6).reshape(2,3)
>>> a
tensor([[0, 1, 2],
[3, 4, 5]])
>>> torch.einsum("ij->ji", a)
tensor([[0, 3],
[1, 4],
[2, 5]])
>>> a
tensor([[0, 1, 2],
[3, 4, 5]])
$$ {\color{green}b} = \sum_i\sum_j {\color{red}A_{ij}} = {\color{red}A_{ij}} $$
a = torch.arange(6).reshape(2,3)
a_ = torch.einsum("ij->", a) #求和操作,结果为一个整数值
>>> a_ = torch.einsum("ij->",a)
>>> a_
tensor(15)
>>> a_.size()
torch.Size([])
$$ {\color{green}b_j} = \sum_i{\color{red}A_{ij}} = {\color{red}A_{ij}} $$
>>> a_ = torch.einsum("ij->j", a) #保持列维度
>>> a_
tensor([3, 5, 7])
$$ {\color{green}b_i} = \sum_j{\color{red}A_{ij}} = {\color{red}A_{ij}} $$
>>> a_ = torch.einsum('ij -> i', a)
>>> a_
tensor([ 3, 12])
$$ {\color{green}c_i} = \sum_k {\color{red}A_{ik}}{\color{blue}b_k} $$
>>> a
tensor([[0, 1, 2],
[3, 4, 5]])
>>> b = torch.arange(3)
>>> b
tensor([0, 1, 2])
>>> a_ = torch.einsum('ij,j -> i', [a, b])
>>> a_
tensor([ 5, 14])
$$ {\color{green}c_{ij}} = \sum_k {\color{red}A_{ik}}{\color{blue}B_{kj}} = {\color{red}A_{ik}}{\color{blue}B_{kj}} $$
>>> a
tensor([[0, 1, 2],
[3, 4, 5]])
>>> b
tensor([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14]])
>>> a_ = torch.einsum('ik,kj -> ij', [a, b])
>>> a_
tensor([[ 25, 28, 31, 34, 37],
[ 70, 82, 94, 106, 118]])
(vector dot product)
$$ {\color{green}c} = \sum_i {\color{red}a_i}{\color{blue}b_i} = {\color{red}a_i}{\color{blue}b_i} $$
a = torch.arange(3)
b = torch.arange(3,6) # -- a vector of length 3 containing [3, 4, 5]
a_ = torch.einsum('i,i->', [a, b])
>>> a_
tensor(14)
(matrix dot product)
$$ {\color{green}c} = \sum_i\sum_j {\color{red}A_{ij}}{\color{blue}B_{ij}} = {\color{red}A_{ij}}{\color{blue}B_{ij}} $$
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6,12).reshape(2, 3)
a_ = torch.einsum('ij,ij->', [a, b])
>>> a_
tensor(145)
$$ {\color{green}C_{ij}} = {\color{red}A_{ij}}{\color{blue}B_{ij}} $$
a = torch.arange(6).reshape(2,3)
>>> a
tensor([[0, 1, 2],
[3, 4, 5]])
b = torch.arange(6, 12).reshape(2,3)
>>> b
tensor([[ 6,  7,  8],
        [ 9, 10, 11]])
a_ = torch.einsum('ij,ij->ij', [a,b])
>>> a_
tensor([[ 0, 7, 16],
[27, 40, 55]])
a = torch.randn([4, 3, 224, 224])  # bcwh
b = torch.randn([4, 3, 224, 224])  # bcwh
a_ = torch.einsum('bcwh -> bc', [a])
## equivalent way of computing a_ with explicit loops
a_ = torch.zeros(a.shape[0], a.shape[1])
for vi in range(a.shape[0]):
    for vj in range(a.shape[1]):
        a_[vi][vj] = torch.sum(a[vi, vj, ...])  # equal to torch.sum(a[vi, vj, :, :])
a_ = torch.einsum('bcwh, bcwh -> bc', [a, b])
## equivalent way of computing a_ with explicit loops
temp = torch.mul(a, b)
intersection_ = torch.zeros(a.shape[0], a.shape[1])
for vi in range(a.shape[0]):
    for vj in range(a.shape[1]):
        intersection_[vi][vj] = torch.sum(temp[vi, vj, :, :])
a_ = torch.einsum('bc -> b', [intersection_])
a_ = torch.sum(intersection_, 1)  # equivalent; dim=1 sums over the columns, dim=0 over the rows
>>> a_.size()
torch.Size([4])
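As a side note (my addition, not from the original text), the same per-(batch, channel) reductions can also be written without einsum by summing over the spatial dimensions directly:

import torch

a = torch.randn(4, 3, 224, 224)
b = torch.randn(4, 3, 224, 224)
a_sum = a.sum(dim=(2, 3))         # same as torch.einsum('bcwh -> bc', [a])
inter = (a * b).sum(dim=(2, 3))   # same as torch.einsum('bcwh, bcwh -> bc', [a, b])
per_batch = inter.sum(dim=1)      # same as torch.einsum('bc -> b', [inter])
print(a_sum.shape, inter.shape, per_batch.shape)  # [4, 3], [4, 3], [4]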
$$ {\color{green}C_{ij}} = {\color{red}a_i}{\color{blue}b_j} $$
Note: Outer Product here refers to the tensor product; Exterior Product refers to the exterior (wedge) product of analytic geometry.
>>> a = torch.arange(3)
>>> b = torch.arange(3,7)
>>> b
tensor([3, 4, 5, 6])
>>> a_ = torch.einsum('i, j -> ij', [a,b])
>>> a_
tensor([[ 0, 0, 0, 0],
[ 3, 4, 5, 6],
[ 6, 8, 10, 12]])
$$ {\color{green}C_{ijl}} = \sum_k{\color{red}A_{ijk}}{\color{blue}B_{ikl}} = {\color{red}A_{ijk}}{\color{blue}B_{ikl}} $$
Note: within each batch this follows the usual rules of matrix multiplication.
$$ A_{3\times{\color{red}{2\times 5}}}\,B_{3\times{\color{red}{5\times 3}}}=C_{3\times{\color{red}{2\times 3}}} $$
>>> a = torch.randn(3,2,5)
>>> a
tensor([[[ 1.0289, -0.1355, -0.4099, -1.4501, -0.5968],
[-0.1667, -0.0286, 2.4426, -1.6169, -1.4544]],
[[ 0.6575, -0.4081, -0.3920, -1.5732, -0.2900],
[ 0.8656, -0.1987, 0.6598, 0.6242, 0.8018]],
[[ 0.5719, 0.0117, 0.6081, 0.0559, -1.6444],
[ 0.6912, -1.0747, -0.6326, -0.7700, 0.4683]]])
>>> b = torch.randn(3,5,3)
>>> b
tensor([[[-0.8491, 0.2271, 0.3570],
[-2.1355, -0.2560, 1.1238],
[ 2.1473, -0.7215, -0.1438],
[ 1.0483, 0.0798, 0.6447],
[-0.0853, 1.1740, -0.9754]],
[[ 0.0629, 0.5650, 0.6519],
[ 0.3167, -0.1258, -0.3802],
[-1.8143, -1.5648, 0.8874],
[-0.6165, 0.7251, 0.5943],
[ 1.5743, -1.2373, -2.3647]],
[[-1.9334, -3.4358, 0.7900],
[ 1.5289, 0.7921, -0.7569],
[ 1.5598, 0.1017, -0.2741],
[-2.2164, -1.4216, 0.6508],
[-0.3690, -0.0338, -1.3386]]])
>>> a_ = torch.einsum('ijk, ikl -> ijl', [a, b]) # batch matrix multiplication, following the usual matmul rules
>>> a_
tensor([[[-2.9337, -0.2524, -0.0787],
[ 3.8770, -3.6294, -0.0668]],
[[ 1.1368, 0.2542, -0.0134],
[-0.3282, -1.0579, -0.2997]],
[[ 0.3435, -1.9180, 2.5137],
[-2.4323, -2.2114, 0.4050]]])
$$ {\color{green}C_{pstuv}} = \sum_q\sum_r{\color{red}A_{pqrs}}{\color{blue}B_{tuqvr}} = {\color{red}A_{pqrs}}{\color{blue}B_{tuqvr}} $$
Note: the Batch Matrix Multiplication above is a special case of Tensor Contraction. Example: take an n-D tensor ${\color{red}A\in\mathbb{R}^{I_1\times\cdots\times I_n}}$ and an m-D tensor ${\color{blue}B\in\mathbb{R}^{J_1\times\cdots\times J_m}}$. Assume n = 4, m = 5, and that $I_2 = J_3$ and $I_3 = J_5$. We can then contract A over its 2nd and 3rd dimensions and B over its 3rd and 5th dimensions (a tensor contraction), which yields the tensor ${\color{green}C\in\mathbb{R}^{I_1\times I_4\times J_1\times J_2\times J_4}}$.
a = torch.randn(2,3,5,7)
b = torch.randn(11,13,3,17,5)
torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape
>>>torch.Size([2, 7, 11, 13, 17])
$$ {\color{green}D_{ij}} = \sum_k\sum_l{\color{red}A_{ik}}{\color{purple}B_{jkl}}{\color{blue}C_{il}} = {\color{red}A_{ik}}{\color{purple}B_{jkl}}{\color{blue}C_{il}} $$
Note: einsum can operate on more than two tensors; a bilinear transformation is one example.
torch.nn.Bilinear(in1_features, in2_features, out_features, bias=True)
Parameters:
in1_features – size of each first input sample
in2_features – size of each second input sample
out_features – size of each output sample
bias – If set to False, the layer will not learn an additive bias. Default: True
>>> a = torch.randn(2,3)
>>> a
tensor([[ 0.6125, 0.2316, -0.6014],
[-0.8615, -0.9863, 1.2960]])
>>> b = torch.randn(5, 3, 7)
>>> b
tensor([[[-1.1676, 0.0660, -0.8437, -0.1520, -2.5223, -0.2943, -1.2686],
[-0.6428, 1.5299, -1.0441, 1.1833, 0.3706, 0.6328, 0.6845],
[-0.0108, -0.6111, -0.8319, -0.6406, -0.9033, -0.7832, -0.1204]],
[[-1.1418, 0.6051, 0.2700, 0.4359, 0.5383, 1.5829, 0.0230],
[ 1.4027, -0.3302, 1.2980, 0.4281, -1.2563, -1.8918, -0.7306],
[ 0.2820, 0.0810, -0.4398, -0.3574, -0.8132, -0.1044, 0.1898]],
[[-1.5765, 0.0495, -1.1407, -0.3947, 0.2474, 1.1779, -0.7999],
[ 1.5090, 0.3881, -0.0344, 0.6265, 0.0975, 0.3283, -0.2570],
[ 0.3752, 0.3056, -1.5868, -0.0883, -1.6226, -0.5940, 1.6271]],
[[ 1.1354, -0.3212, 0.4247, 0.0681, 0.9360, -0.3384, 1.4519],
[-0.1396, -0.8205, -0.1776, 0.4195, 0.2457, -0.9302, 1.4448],
[-0.7819, -0.9069, -1.4125, -0.3920, -1.0380, 2.0351, 0.6595]],
[[-1.2359, 0.3135, -0.1942, 0.2450, -1.0833, -0.1889, 0.2017],
[ 0.1240, -1.3459, -1.1797, 0.2006, 1.5969, 0.4993, 1.2332],
[ 0.1330, 0.2703, 0.5384, 2.3118, -0.3098, -0.0144, 0.7328]]])
>>> c = torch.randn(2,7)
>>> c
tensor([[-0.1731, 1.9290, 0.1928, 0.6679, 0.7227, -1.8295, 0.8021],
[-1.0124, 1.5828, 1.6196, -0.1516, 0.9401, 0.2667, -1.2518]])
>>> a_ = torch.einsum('ij,kjn, in -> ik', [a, b, c])
>>> a_.size()
torch.Size([2, 5])
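As a hedged sketch (my addition), the same three-tensor contraction matches PyTorch's built-in torch.nn.functional.bilinear when b above is interpreted as a weight tensor of shape [out_features, in1_features, in2_features]:

import torch
import torch.nn.functional as F

a = torch.randn(2, 3)     # input1: [batch, in1_features]
b = torch.randn(5, 3, 7)  # weight: [out_features, in1_features, in2_features]
c = torch.randn(2, 7)     # input2: [batch, in2_features]

out_einsum = torch.einsum('ij,kjn,in->ik', [a, b, c])
out_builtin = F.bilinear(a, c, b)  # bias omitted
print(torch.allclose(out_einsum, out_builtin, atol=1e-6))  # expected: True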
The einsum implementation of Equation (6) in TreeQN is as follows:
Given a low-dimensional state representation $z_l$ at layer $l$ and a transition (activation) function $W^a$ for each action $a$, the state representation $z^a_{l+1}$ of the next layer $l+1$ is computed using a residual connection:
$$ \mathbf{z}^a_{l+1} = \mathbf{z}_l + \tanh(\mathbf{W}^a\mathbf{z}_l) $$
In practice, we want to compute, at the same time and for all transition functions (i.e. all actions $A$), the next $K$-dimensional state representations for a batch of size $B$, $Z\in\mathbb{R}^{B\times K}$. We can arrange these transition functions in a tensor $W\in\mathbb{R}^{A\times K\times K}$ and compute the next state representations efficiently with einsum.
import torch.nn.functional as F

def random_tensors(shape, num=1, requires_grad=False):
    tensors = [torch.randn(shape, requires_grad=requires_grad) for i in range(0, num)]
    return tensors[0] if num == 1 else tensors

# Parameters
# -- [num_actions x hidden_dimension]; b is a parameter to be learned, so it requires a gradient, and so does W
b = random_tensors([5, 3], requires_grad=True)
# -- [num_actions x hidden_dimension x hidden_dimension]
W = random_tensors([5, 3, 3], requires_grad=True)

def transition(zl):
    # zl.unsqueeze(1): convert zl to shape [2, 1, 3]
    # einsum returns a tensor of shape [2, 5, 3]; b has shape [5, 3], so '+' broadcasts
    # return shape [2, 5, 3] -- [batch_size x num_actions x hidden_dimension]
    return zl.unsqueeze(1) + F.tanh(torch.einsum("bk,aki->bai", [zl, W]) + b)
# Sampled dummy inputs
# -- [batch_size x hidden_dimension]
zl = random_tensors([2, 3])
>>> transition(zl).shape
torch.Size([2, 5, 3])
>>> transition(zl)
tensor([[[-2.4080, -3.5068, 0.5025],
[-2.1035, -1.5754, 2.4678],
[-2.9639, -3.5726, 2.4945],
[-2.8011, -1.9178, 0.4953],
[-0.9763, -3.5736, 0.5021]],
[[ 2.0257, -2.7589, 0.1339],
[ 0.9365, -0.7853, -1.8133],
[ 0.6122, -2.7831, 0.1818],
[ 2.3801, -2.4952, -1.8072],
[ 1.7467, -2.7832,  0.1259]]], grad_fn=<AddBackward0>)