目录
1. 向量点乘(内积)
2. 外积
3. 矩阵和向量的乘法
4. 矩阵乘法
5. 批量矩阵乘法
6. 求和操作
torch.einsum
是 PyTorch 中的一个强大工具,它允许你通过 Einstein summation convention(爱因斯坦求和约定)来执行复杂的张量操作。使用这种约定,你可以用一个字符串来指定张量操作的维度规则。这个函数非常灵活,可以用于实现各种张量运算,如元素相乘、矩阵乘法、批量矩阵乘法、迹等。
torch.einsum
的语法格式如下:
torch.einsum(equation, *operands)
其中,equation
是一个字符串,指定了张量操作的约定;*operands
是要操作的张量。
以下是 torch.einsum
的一些常见用法:
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([0, 1, 0])
# 计算两个向量的内积
result = torch.einsum('i,i->', a, b) # 输出: tensor(2)
这里 'i,i->'
告诉 einsum
计算 a
和 b
的对应元素相乘的和,最后结果是一个标量。
# 计算两个向量的外积
result = torch.einsum('i,j->ij', a, b) # 输出: 2x3 矩阵
这里 'i,j->ij'
表示 a
的每个元素(索引为 i
)与 b
的每个元素(索引为 j
)相乘,得到一个二维矩阵。
A = torch.tensor([[1, 2], [0, 1], [2, 0]])
v = torch.tensor([0, 1])
# 矩阵和向量的乘法
result = torch.einsum('ij,j->i', A, v) # 输出: 1D tensor of size 3
这里 'ij,j->i'
表示 A
的每一行与向量 v
相乘,结果是一个向量。
A = torch.tensor([[1, 2], [0, 1]])
B = torch.tensor([[2, 0], [0, 2]])
# 矩阵乘法
result = torch.einsum('ik,kj->ij', A, B) # 输出: 2x2 矩阵
A = torch.randn(3, 2, 5)
B = torch.randn(3, 5, 3)
# 批量矩阵乘法 (Batched Matrix Multiplication)
result = torch.einsum('bik,bkj->bij', A, B) # 输出: 3x2x3 矩阵
这里 'bik,bkj->bij'
表示对于批量中的每个矩阵,进行矩阵乘法。
A = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 计算矩阵中所有元素的和
result = torch.einsum('ij->', A) # 输出: tensor(21)
这里 'ij->'
表示对 A
的所有元素进行求和。
torch.einsum
的强大之处在于,你可以通过正确地安排这些字母和箭头来执行非常复杂的操作。需要注意的是,einsum
操作的效率可能不如专门的函数(如 torch.matmul
对于矩阵乘法),但它提供了一种非常简洁和通用的方式来表达复杂的张量计算。