Computing with torch.einsum in Python


Overview

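  • The Einstein summation convention (einsum) is a concise and elegant notation. By specifying how elements are multiplied and summed, it expresses tensor operations such as vector inner products, vector outer products, matrix multiplication, transposition, and tensor contraction.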

For example:

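import torch

a = torch.rand(3, 4)
b = torch.rand(4, 5)
c = torch.einsum("ik,kj->ij", a, b)
# The first argument of einsum is the equation string. "ik,kj->ij" means: multiply the element in row i, column k of a
# by the element in row k, column j of b, sum these products over k, and store the result as the element in row i, column j of the output
# Each dimension label must be a single letter, e.g. one of 'a'-'z'
# With a single operand, you can think of each element as being multiplied by 1 before the summation

# The part from the arrow onward may be omitted (implicit mode): for matrix multiplication, "ik,kj" is equivalent to "ik,kj->ij".
# In implicit mode the output keeps the indices that appear exactly once in the inputs, sorted alphabetically
# The equation also supports an ellipsis "...", which stands for all dimensions that are not labeled explicitly;
# e.g. for a 5-D tensor, "...ij" labels only the last two dimensions and leaves the rest to "..."

# The remaining arguments of einsum are the operands (here a and b), i.e. the actual input tensors; their shapes must match the equation
# Output indices may appear in any order: rewriting "ik,kj->ij" as "ik,kj->ji" returns the transpose of the original result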

$$c_{(i,j)} = \sum_{k} a_{(i,k)} \cdot b_{(k,j)}$$
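To make the formula concrete, here is a minimal sketch that evaluates it with explicit Python loops and compares the result with einsum. The helper name `matmul_by_loops` is hypothetical, and the triple loop is only for illustration; it is far slower than einsum.

```python
import torch

def matmul_by_loops(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: evaluate c[i, j] = sum_k a[i, k] * b[k, j] with explicit loops."""
    n_i, n_k = a.shape
    n_k2, n_j = b.shape
    assert n_k == n_k2, "the summed index k must have the same size in both operands"
    c = torch.zeros(n_i, n_j, dtype=a.dtype)
    for i in range(n_i):          # i appears in the output
        for j in range(n_j):      # j appears in the output
            for k in range(n_k):  # k is summed away
                c[i, j] += a[i, k] * b[k, j]
    return c

a = torch.rand(3, 4)
b = torch.rand(4, 5)
c_loops = matmul_by_loops(a, b)
c_einsum = torch.einsum("ik,kj->ij", a, b)
print(torch.allclose(c_loops, c_einsum, atol=1e-6))  # True
```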

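  • Free indices: indices that appear on both sides of the arrow, such as i and j above. They describe how the positions of input elements map to the positions of output elements.
  • Summation indices: indices that appear only to the left of the arrow (i.e. only inside the sum on the right-hand side of the equation), such as k above. Input elements are multiplied and summed along these dimensions to produce each output element.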
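Implicit mode (an equation without `->`) applies exactly this classification: an index that appears once is treated as free and kept in the output, sorted alphabetically, while a repeated index is summed. The sketch below checks a few consequences of that rule; the tensors `x`, `y`, and `sq` are made up for illustration. One perhaps surprising consequence is that `"ba"` on its own yields the transpose, because the output indices are reordered to `ab`.

```python
import torch

x = torch.arange(6.0).reshape(2, 3)
y = torch.arange(12.0).reshape(3, 4)

# "ik,kj": i and j appear once (free), k appears twice (summed) -> output "ij"
implicit_mm = torch.einsum("ik,kj", x, y)
print(torch.equal(implicit_mm, torch.einsum("ik,kj->ij", x, y)))  # True

# "ba": both indices are free and get sorted alphabetically -> output "ab",
# so dim 1 of x (labeled a) becomes dim 0 of the output: a transpose
implicit_t = torch.einsum("ba", x)
print(torch.equal(implicit_t, x.T))  # True

# "ii": i is repeated, so it is summed -> the trace, a scalar
sq = torch.arange(9.0).reshape(3, 3)
print(torch.einsum("ii", sq))  # tensor(12.)
```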

Practice

import torch
import numpy as np

# 1: Matrix multiplication
a = torch.rand(2, 3)
b = torch.rand(3, 4)

ein_out = torch.einsum("ik,kj->ij", a, b).numpy()  # ein_out = torch.einsum("ik,kj", a, b).numpy()
org_out = torch.mm(a, b).numpy()

print("input:\n", a, b,  sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))

# 2: Element-wise (Hadamard) product of matrices
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6, 12).reshape(2, 3)

ein_out = torch.einsum('ij,ij->ij', a, b).numpy()
org_out = torch.mul(a, b).numpy()

print("input:\n", a, b,  sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
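Dropping an index from the output turns it into a summation index. As a small sketch building on example 2 (same input values): keeping `ij` gives the element-wise product, `'ij,ij->i'` sums the products along each row, and `'ij,ij->'` sums all of them into one scalar (the Frobenius inner product).

```python
import torch

a = torch.arange(6.0).reshape(2, 3)
b = torch.arange(6.0, 12.0).reshape(2, 3)

elementwise = torch.einsum("ij,ij->ij", a, b)  # no summation index
row_dots = torch.einsum("ij,ij->i", a, b)      # j is summed: one dot product per row
total = torch.einsum("ij,ij->", a, b)          # i and j are summed: a scalar

print(torch.equal(row_dots, (a * b).sum(dim=1)))  # True
print(total)  # tensor(145.)
```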

# 3: Batched matrix multiplication over the last two dimensions
a = torch.randn(2, 3, 5)
b = torch.randn(2, 5, 3)

ein_out = torch.einsum('ijk,ikl->ijl', a, b).numpy()
org_out = torch.matmul(a, b).numpy()  # org_out = torch.bmm(a, b).numpy()  # batched matrix multiplication

print("input:\n", a, b,  sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
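Because output indices can be chosen freely, einsum can also absorb a transpose into the product. A common case, shown here as a sketch with made-up shapes and tensor names, multiplies one batch of matrices by the transpose of another (as when computing attention scores): `'bqd,bkd->bqk'` should match `queries @ keys.transpose(-2, -1)` without writing the transpose explicitly.

```python
import torch

queries = torch.randn(2, 3, 5)  # (batch, rows, feature dim), shapes chosen for illustration
keys = torch.randn(2, 4, 5)     # (batch, rows, feature dim)

scores_einsum = torch.einsum("bqd,bkd->bqk", queries, keys)  # d is summed; b, q, k are free
scores_matmul = queries @ keys.transpose(-2, -1)             # (2, 3, 5) @ (2, 5, 4)

print(scores_einsum.shape)  # torch.Size([2, 3, 4])
print(torch.allclose(scores_einsum, scores_matmul, atol=1e-6))  # True
```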

# 4: Matrix transpose
a = torch.arange(6).reshape(2, 3)

ein_out = torch.einsum('ij->ji', a).numpy()
org_out = torch.transpose(a, 0, 1).numpy()

print("input:\n", a)
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))

# 5: Transposing the last two dimensions of a tensor
a = torch.randn(1, 2, 3, 4, 5)

ein_out = torch.einsum('...ij->...ji', a).numpy()
org_out = a.permute(0, 1, 2, 4, 3).numpy()

print("input:\n", a)
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
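The ellipsis also works across several operands. In the sketch below (shapes are arbitrary), `'...ij,...jk->...ik'` multiplies the last two dimensions while the leading dimensions covered by `...` broadcast against each other, which is expected to match `torch.matmul`.

```python
import torch

a = torch.randn(4, 1, 2, 3)  # leading dims (4, 1), matrices of shape 2x3
b = torch.randn(1, 5, 3, 6)  # leading dims (1, 5), matrices of shape 3x6

out_einsum = torch.einsum("...ij,...jk->...ik", a, b)  # leading dims broadcast to (4, 5)
out_matmul = torch.matmul(a, b)

print(out_einsum.shape)  # torch.Size([4, 5, 2, 6])
print(torch.allclose(out_einsum, out_matmul, atol=1e-6))  # True
```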

# 6: Matrix summation (all elements, per row, per column)
a = torch.arange(6).reshape(2, 3)

ein_out = torch.einsum('ij->', a).numpy()
org_out = torch.sum(a).numpy()

ein_out_i = torch.einsum('ij->i', a).numpy()
org_out_i = torch.sum(a, dim=1).numpy()

ein_out_j = torch.einsum('ij->j', a).numpy()
org_out_j = torch.sum(a, dim=0).numpy()

print("input:\n", a)
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))

print("input:\n", a)
print("ein_out_i: \n", ein_out_i)
print("org_out_i: \n", org_out_i)
print("is org_out_i == ein_out_i ?", np.allclose(ein_out_i, org_out_i))

print("input:\n", a)
print("ein_out_j: \n", ein_out_j)
print("org_out_j: \n", org_out_j)
print("is org_out_j == ein_out_j ?", np.allclose(ein_out_j, org_out_j))

# 7: Extracting the diagonal of a matrix
a = torch.arange(9).reshape(3, 3)

ein_out = torch.einsum('ii->i', a).numpy()
org_out = torch.diagonal(a, 0).numpy()

print("input:\n", a)
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
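A repeated index that is also dropped from the output gets summed. As a sketch extending example 7: `'ii->'` sums the diagonal (the trace), and combined with the ellipsis, `'...ii->...i'` extracts the diagonal of every matrix in a batch, which should agree with `torch.diagonal(..., dim1=-2, dim2=-1)`.

```python
import torch

a = torch.arange(9.0).reshape(3, 3)
print(torch.einsum("ii->", a), torch.trace(a))  # tensor(12.) tensor(12.)

batch = torch.arange(18.0).reshape(2, 3, 3)  # two 3x3 matrices
diag_einsum = torch.einsum("...ii->...i", batch)
diag_torch = torch.diagonal(batch, dim1=-2, dim2=-1)
print(diag_einsum.tolist())  # [[0.0, 4.0, 8.0], [9.0, 13.0, 17.0]]
print(torch.equal(diag_einsum, diag_torch))  # True
```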

# 8: Matrix-vector multiplication
a = torch.rand(3, 4)
b = torch.arange(4.0)

ein_out = torch.einsum('ik,k->i', a, b).numpy()  # ein_out = torch.einsum('ik,k', [a, b]).numpy()
org_out = torch.mv(a, b).numpy()

print("input:\n", a, b,  sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
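einsum is not limited to two operands. As a sketch going one step beyond the matrix-vector product, the bilinear form xᵀAy fits in a single call: `'i,ij,j->'` sums over both i and j and should agree with `x @ A @ y`. The names `x`, `A`, and `y` and their values are chosen only for illustration.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
A = torch.arange(12.0).reshape(3, 4)
y = torch.tensor([1.0, 0.0, -1.0, 2.0])

bilinear = torch.einsum("i,ij,j->", x, A, y)  # sum over i and j of x[i] * A[i, j] * y[j]
print(bilinear, x @ A @ y)  # tensor(88.) tensor(88.)
```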

# 9: Vector inner product
a = torch.arange(3)
b = torch.arange(3, 6)

ein_out = torch.einsum('i,i->', a, b).numpy()  # ein_out = torch.einsum('i,i', [a, b]).numpy()
org_out = torch.dot(a, b).numpy()

print("input:\n", a, b,  sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))

# 10: Vector outer product
a = torch.arange(3)
b = torch.arange(3, 5)

ein_out = torch.einsum('i,j->ij', a, b).numpy()  # ein_out = torch.einsum('i,j', [a, b]).numpy()
org_out = torch.outer(a, b).numpy()

print("input:\n", a, b,  sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
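An outer product has no summation index at all, so it reduces to broadcast multiplication. The sketch below adds a batch index b (an assumption for illustration, not part of example 10) and checks `'bi,bj->bij'` against broadcasting with `unsqueeze`.

```python
import torch

a = torch.randn(4, 3)  # batch of 4 vectors of length 3
b = torch.randn(4, 2)  # batch of 4 vectors of length 2

outer_einsum = torch.einsum("bi,bj->bij", a, b)
outer_broadcast = a.unsqueeze(2) * b.unsqueeze(1)  # (4, 3, 1) * (4, 1, 2) -> (4, 3, 2)

print(outer_einsum.shape)  # torch.Size([4, 3, 2])
print(torch.allclose(outer_einsum, outer_broadcast))  # True
```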

# 11: Tensor contraction
a = torch.randn(1, 3, 5, 7)
b = torch.randn(11, 33, 3, 55, 5)

ein_out = torch.einsum('pqrs,tuqvr->pstuv', a, b).numpy()
org_out = torch.tensordot(a, b, dims=([1, 2], [2, 4])).numpy()

print("input:\n", a, b,  sep='\n')
print("ein_out: \n", ein_out)
print("org_out: \n", org_out)
print("is org_out == ein_out ?", np.allclose(ein_out, org_out))
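What the contraction in example 11 computes can be spelled out with an ordinary matrix multiplication: move the summation indices q and r of both operands next to each other, flatten them into one axis, multiply, and reshape back. The sketch below does this by hand for the same shapes; it illustrates the idea and is not a description of how einsum or tensordot are implemented internally.

```python
import torch

a = torch.randn(1, 3, 5, 7)        # indices p q r s
b = torch.randn(11, 33, 3, 55, 5)  # indices t u q v r
P, Q, R, S = a.shape
T, U, _, V, _ = b.shape

# a: (p, q, r, s) -> (p, s, q, r) -> (p*s, q*r)
a_mat = a.permute(0, 3, 1, 2).reshape(P * S, Q * R)
# b: (t, u, q, v, r) -> (q, r, t, u, v) -> (q*r, t*u*v)
b_mat = b.permute(2, 4, 0, 1, 3).reshape(Q * R, T * U * V)

# one matrix product sums over the flattened (q, r) axis, then restore p s t u v
manual = (a_mat @ b_mat).reshape(P, S, T, U, V)
ein = torch.einsum("pqrs,tuqvr->pstuv", a, b)

print(manual.shape)  # torch.Size([1, 7, 11, 33, 55])
print(torch.allclose(manual, ein, atol=1e-4))  # True
```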
