欢迎关注我的微信公号:小张Python
einsum
全称 Einstein summation convention(爱因斯坦求和约定),用简单的方式来代表多维数组运算;
矩阵求各元素之和
A = ∑ i = 0 n ∑ j = 0 n a i a j A = \sum_{i=0}^n\sum_{j=0}^na_ia_j A=i=0∑nj=0∑naiaj
如果用 einsum函数可表示为
A = np.einsum('i',a)
多矩阵相乘
C i k = ∑ j = 1 n A i j B j k C_{ik} = \sum_{j=1}^nA_{ij}B_{jk} Cik=j=1∑nAijBjk
表示为
C_i = np.einsum('ij,jk',a,b)
求矩阵的迹
t r a c e = ∑ i = 1 n A i i trace = \sum_{i=1}^nA_{ii} trace=i=1∑nAii
用 einsum函数可表示为
trace = np.einsum('ii',a)
对于数组间运算例如矩阵乘积和、沿某一轴点积和;enisum 为多维数组运算提供另外一种表示方式,指定的下标标签列表,用逗号隔开;enisum
函数在 numpy ,Pytorch,Tensorflow 都有实现,使用方式如下
np.einsum(subscripts:str,operands:list of array_like)
函数中参数 subscripts 为字符串类型,表示运算命令,例如 "ii","ij,jk"
参数 operands 代表需要计算的数组或数组列表;
enisum 中 subscrpts 参数的字符串形式有两种方式:
1,implicit(隐式模式)
不包含->
标识符和输出标签;输出数组会根据选择的下标顺序进行排序,例如 np.einsum('ij',a)
得到的二维数组无变化,但 np.einsum('ji',a)
需要对输出数组进行转置 (i ,j 轴互换
)操作
2,explicit(显式模式)
包含 标识符->
及输出标签,能够增加函数的灵活性,例如调用 np.einsum('i->',a)
效果类似于 np.sum(a,axis = -1)
;而 np.einsum('ii->i',a)
等同于 np.diag(a)
;另外在显式模式中,会直接指定输出数组下标顺序,例如 np.einsum('ij,jh->ih',a,b)
表示矩阵相乘;目前下面运算都可用 enisum
函数表示;
numpy.trace
;numpy.diag
numpy.sum
numpy.transpose
numpy.matmul,numpy.dot
numpy.inner,numpy.outer
numpy.multiply
numpy.tensorbot
关于 np.enisum()
函数举几个栗子:
>>> a = np.arange(25).reshape(5,5)
>>> b = np.arange(5)
>>> c = np.arange(6).reshape(2,3)
1,计算矩阵的迹
>>> np.einsum('ii...->...i',a)
array([-0.796318 , 0.08363816, -0.79171551, 0.36235911])
>>> np.einsum('i...i',a)
-1.1420362461348776
>>> a
array([[-0.796318 , 1.54759498, -0.744291 , 0.02107445],
[ 0.03826498, 0.08363816, 0.92709203, 0.04769788],
[ 0.39088153, -0.85566069, -0.79171551, -1.50750047],
[-1.16165527, 0.77327936, 0.44133708, 0.36235911]])
>>> np.trace(a)
-1.1420362461348776
2,矩阵相乘
>>> b =np.random.rand(4,5)
>>> np.einsum('ij...,jk...->ik...',a,b)
array([[-0.207485 , 0.37929742, -1.14191507, -0.30398675, -0.59431733],
[ 0.03389816, 0.10101184, 0.78917293, 0.41502013, 0.37634113],
[-0.3499671 , -0.45043889, -1.67700784, -0.68798725, -0.49966522],
[-0.2186732 , -0.30529867, -0.21863002, 0.02300085, -0.50540297]])
>>> np.matmul(a,b)
array([[-0.207485 , 0.37929742, -1.14191507, -0.30398675, -0.59431733],
[ 0.03389816, 0.10101184, 0.78917293, 0.41502013, 0.37634113],
[-0.3499671 , -0.45043889, -1.67700784, -0.68798725, -0.49966522],
[-0.2186732 , -0.30529867, -0.21863002, 0.02300085, -0.50540297]])
3,对角线 diag
>>> np.einsum('ii->i',a)
array([ 0, 6, 12, 18, 24])
>>> np.diag(a)
array([ 0, 6, 12, 18, 24])
4,沿数组某一轴求和(需要在 explicit 模式下运行)
>>> # 沿着某一轴求和,显示模式运行
>>> np.einsum('ij->i',a)
array([ 10, 35, 60, 85, 110])
d
>>> np.sum(a,axis = 1)
array([ 10, 35, 60, 85, 110])
5,数组转置,改变轴顺序
>>> # 计算数组转置,对某些轴重新排序
>>> c
array([[0, 1, 2],
[3, 4, 5]])
>>> np.einsum('ji',c)
array([[0, 3],
[1, 4],
[2, 5]])
>>> np.transpose(c)
array([[0, 3],
[1, 4],
[2, 5]])
6,inner(计算内积)
>>> np.einsum('i,i',b,b)
30
>>> b
array([0, 1, 2, 3, 4])
>>> np.inner(b,b)# 一维数组,逐像素乘积和
30
7,矩阵点乘
>>> np.einsum('ij,j',a,b)
array([ 30, 80, 130, 180, 230])
>>> np.dot(a,b)
array([ 30, 80, 130, 180, 230])
>>> a
array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]])
>>> b
array([0, 1, 2, 3, 4])
>>> np.dot(a[-1],b)# Test
230
8,张量数组相乘
>>> np.einsum(',ij',3,c)
array([[ 0, 3, 6],
[ 9, 12, 15]])
>>> np.multiply(3,c)
array([[ 0, 3, 6],
[ 9, 12, 15]])
>>> # 向量外积
8,广播机制,outer
>>> b
array([0, 1, 2, 3, 4])
>>> np.einsum('i,j',np.arange(2)+1,b)
array([[0, 1, 2, 3, 4],
[0, 2, 4, 6, 8]])
>>> np.outer(np.arange(2)+1,b)
array([[0, 1, 2, 3, 4],
[0, 2, 4, 6, 8]])
9,Tensor contraction,数组收缩沿着某一轴相乘
>>> a = np.arange(60).reshape(3,4,5)
>>> b = np.arange(24).reshape(4,3,2)
>>> np.einsum('ijk,jil->kl',a,b)
array([[4400, 4730],
[4532, 4874],
[4664, 5018],
[4796, 5162],
[4928, 5306]])
>>> np.tensordot(a,b,axes = [[1,0],[0,1]])
array([[4400, 4730],
[4532, 4874],
[4664, 5018],
[4796, 5162],
[4928, 5306]])
注:在 numpy 版本 1.12.0 之后,einsum
加入了 optimize
参数,用来优化 contraction
操作,对于 contraction
运算部分,操作的数组包含三个或三个以上,optimize
参数设置能提高计算效率,减小内存占比;
Reference:
1,https://zhuanlan.zhihu.com/p/71639781
2,https://numpy.org/doc/stable/reference/generated/numpy.einsum.html