涉及以下内容
简述
例如
a = torch.rand(3, 4)
b = torch.rand(4, 5)
c = torch.einsum("ik,kj->ij", a, b)  # (3, 4) x (4, 5) -> (3, 5), same result as a @ b
# The first argument of einsum is the equation string. "ik,kj->ij" means: for each
# output position (i, j), multiply element (i, k) of `a` (row i, column k) by element
# (k, j) of `b` (row k, column j) and sum those products over k.
# Each index label is a single letter. Lowercase 'a'-'z' always works; recent PyTorch
# versions also accept uppercase 'A'-'Z'.
# With only one operand there is nothing to multiply. You can think of every element as
# being multiplied by 1 before the summation, e.g. "ij->i" just sums over j.
# The "->output" part may be omitted (implicit mode): "ik,kj" is equivalent to
# "ik,kj->ij". The output then keeps every index that appears exactly once in the
# inputs, in alphabetical order.
# The equation may use an ellipsis "..." to stand for dimensions that are not named
# explicitly, e.g. the leading dimensions of a 5-D tensor the computation does not
# care about.
# The remaining arguments are the input tensors (here a and b). Their actual shapes
# must be consistent with the equation, e.g. k must have the same size in a and b.
# The output index order is free: writing "ik,kj->ji" instead of "ik,kj->ij" returns
# the transpose of the latter.
$$c_{(i,j)}=\sum_{k}a_{(i,k)}\cdot b_{(k,j)}$$
实践
import torch
import numpy as np
# 1: matrix multiplication.
# "ik,kj->ij" sums over the shared index k, which is exactly what torch.mm does;
# the implicit form torch.einsum("ik,kj", a, b) gives the same result.
a, b = torch.rand(2, 3), torch.rand(3, 4)
ein_out = torch.einsum("ik,kj->ij", a, b).numpy()
org_out = torch.mm(a, b).numpy()
print("input:\n", a, b, sep='\n')
for label, value in (("ein_out", ein_out), ("org_out", org_out)):
    print(f"{label}: \n", value)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 2: element-wise (Hadamard) product.
# Every index appears in the output ("ij,ij->ij"), so nothing is summed:
# each output element is just a[i, j] * b[i, j].
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6, 12).reshape(2, 3)
ein_out = torch.einsum('ij,ij->ij', a, b).numpy()
org_out = a.mul(b).numpy()
print("input:\n", a, b, sep='\n')
for label, value in (("ein_out", ein_out), ("org_out", org_out)):
    print(f"{label}: \n", value)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 3: batched matrix multiplication over the last two dimensions.
# Index i appears in both operands and in the output, so it acts as a batch index;
# k is summed. For these 3-D inputs torch.bmm(a, b) gives the same result.
a = torch.randn(2, 3, 5)
b = torch.randn(2, 5, 3)
ein_out = torch.einsum('ijk,ikl->ijl', a, b).numpy()
org_out = torch.matmul(a, b).numpy()
same = np.allclose(ein_out, org_out)
print("input:\n", a, b, sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", same)
# 4: matrix transpose.
# Reversing the index order on the output side ("ij->ji") swaps rows and columns.
a = torch.arange(6).reshape(2, 3)
ein_out, org_out = (
    torch.einsum('ij->ji', a).numpy(),
    a.transpose(0, 1).numpy(),
)
print("input:\n", a)
for label, value in (("ein_out", ein_out), ("org_out", org_out)):
    print(f"{label}: \n", value)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 5: transpose the last two dimensions of a tensor of any rank.
# "..." stands for all leading dimensions, so '...ij->...ji' works unchanged for
# 2-D, 3-D, 5-D, ... inputs. torch.transpose(a, -2, -1) is the equally rank-agnostic
# reference. A hard-coded a.permute(0, 1, 2, 4, 3) would only fit 5-D input.
a = torch.randn(1, 2, 3, 4, 5)
ein_out = torch.einsum('...ij->...ji', a).numpy()
org_out = torch.transpose(a, -2, -1).numpy()
print("input:\n", a)
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 6: sums.
# Indices missing from the output are summed over:
#   'ij->'  : sum of all elements
#   'ij->i' : sum over j, one value per row    (torch.sum(a, dim=1))
#   'ij->j' : sum over i, one value per column (torch.sum(a, dim=0))
a = torch.arange(6).reshape(2, 3)
ein_out = torch.einsum('ij->', a).numpy()
org_out = torch.sum(a).numpy()
ein_out_i = torch.einsum('ij->i', a).numpy()
org_out_i = torch.sum(a, dim=1).numpy()
ein_out_j = torch.einsum('ij->j', a).numpy()
org_out_j = torch.sum(a, dim=0).numpy()
# Each case compares its own pair of results, so the row and column checks
# actually test ein_out_i/org_out_i and ein_out_j/org_out_j.
for suffix, ein, org in (
    ("", ein_out, org_out),
    ("_i", ein_out_i, org_out_i),
    ("_j", ein_out_j, org_out_j),
):
    print("input:\n", a)
    print(f"ein_out{suffix}: \n", ein)
    print(f"org_out{suffix}: \n", org)
    print(f"is org_out{suffix} == ein_out{suffix} ?", np.allclose(ein, org))
# 7: extract the main diagonal.
# Repeating an index within one operand ('ii') selects the elements a[i, i];
# keeping i in the output ('->i') returns them as a vector instead of summing them.
a = torch.arange(9).reshape(3, 3)
ein_out, org_out = (
    torch.einsum('ii->i', a).numpy(),
    a.diagonal(0).numpy(),
)
print("input:\n", a)
for label, value in (("ein_out", ein_out), ("org_out", org_out)):
    print(f"{label}: \n", value)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 8: matrix-vector product.
# "ik,k->i" sums over k for every row i of `a`; the implicit form "ik,k" is equivalent.
a = torch.rand(3, 4)
b = torch.arange(4.0)  # the 4.0 makes b floating point, matching a's dtype for torch.mv
ein_out = torch.einsum('ik,k->i', a, b).numpy()
org_out = torch.mv(a, b).numpy()
print("input:\n", a, b, sep='\n')
# Labels name the variables actually printed and compared (as in the other examples).
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 9: inner (dot) product of two vectors.
# Index i appears in both inputs but not in the output, so the products are summed
# to a scalar; the implicit form 'i,i' gives the same result.
a, b = torch.arange(3), torch.arange(3, 6)
ein_out = torch.einsum('i,i->', a, b).numpy()
org_out = a.dot(b).numpy()
same = np.allclose(ein_out, org_out)
print("input:\n", a, b, sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", same)
# 10: outer product of two vectors.
# Distinct indices i and j are both kept, so nothing is summed:
# out[i, j] = a[i] * b[j].
a, b = torch.arange(3), torch.arange(3, 5)
ein_out = torch.einsum('i,j->ij', a, b).numpy()
org_out = torch.outer(a, b).numpy()
print("input:\n", a, b, sep='\n')
for label, value in (("ein_out", ein_out), ("org_out", org_out)):
    print(f"{label}: \n", value)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
# 11: tensor contraction over several indices at once.
# q (size 3) and r (size 5) occur in both operands but not in the output, so they
# are summed over. torch.tensordot names the same axes by position: axes 1 and 2
# of `a` against axes 2 and 4 of `b`. Result shape is (p, s, t, u, v).
a = torch.randn(1, 3, 5, 7)
b = torch.randn(11, 33, 3, 55, 5)
contracted_axes = ([1, 2], [2, 4])
ein_out = torch.einsum('pqrs,tuqvr->pstuv', a, b).numpy()
org_out = torch.tensordot(a, b, dims=contracted_axes).numpy()
print("input:\n", a, b, sep='\n')
for label, value in (("ein_out", ein_out), ("org_out", org_out)):
    print(f"{label}: \n", value)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))