tf.einsum(
equation, *inputs, **kwargs
)
用于实现tensorflow中张量的点积、外积、装置、矩阵乘法等操作,einsum是一个表达以上这些运算,包括复杂张量运算在内的优雅方式,基本上,可以把einsum看成一种领域特定语言。一旦你理解并能利用einsum,除了不用记忆和频繁查找特定库函数这个好处以外,你还能够更迅速地编写更加紧凑、高效的代码。
比如,两个张量A、B相乘,得到矩阵C,用公式表示如下:
ij,jk->ik
参数equation表示不同操作对应的einsum记法,是一个字符串;参数*inputs表示输入的多个张量,形状要和equation中的对应。
1、矩阵点积,c = sum_i a[i]*b[i]
>>> a = tf.constant([1, 2, 3, 4])
>>> b = tf.constant([2, 2, 2, 2])
>>> tf.einsum("i,i->", a, b)
2、矩阵乘法,c[i, k] = sum_j a[i, j]*b[j, k]
>>> a = tf.ones([2, 2])
>>> b = tf.ones([2, 2])
>>> tf.einsum("ij,jk->ik", a, b)
3、矩阵转置, b[j, i] = a[i, j]
>>> a = tf.constant([[1, 2], [3, 4]])
>>>> a
>>> tf.einsum("ij->ji", a)
4、获取矩阵的对角元素
>>> a = tf.linalg.band_part(tf.ones((3, 3)), 0, 0)
>>> a
>>> tf.einsum("ii->i", a)