最近在准备做 HW04,在读 transformer 的源码的时候发现 attention score 的 torch.matmul() 的奇妙设置,故有此篇文章进行分享。
前言碎碎念:
一开始我以为 torch.matmul 所做的工作就是简单的矩阵相乘,即:假设我们有两个矩阵
A
和B
,它们的 size 分别为(m, n)
和(n, p)
,那么 A x B 的 size 为 (m, p)。然后我看了眼官方文档的例子:tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(10, 4, 5) torch.matmul(tensor1, tensor2).size()
>> torch.Size([10, 3, 5])
大大的问号冒了出来 : ),这也能乘?
文章的代码文件:notebook 代码
torch
.matmul
(input, other, ***, out=None) → Tensor两个张量的矩阵乘积,具体行为取决于张量的维度,如下所示。
这里为了描述方便,用 input_d
和 other_d
分别指代 input.dim()
和 other.dim()
,使用 torch.randint()
替代 torch.randn()
方便印证。
import torch
# 固定 torch 的随机数种子,以便重现结果
torch.manual_seed(0)
# 打印信息
def print_info(A, B):
print(f"A: {A}\nB: {B}")
print(f"A 的维度: {A.dim()},\t B 的维度: {B.dim()}")
print(f"A 的元素总数: {A.numel()},\t B 的元素总数: {B.numel()}")
print(f"torch.matmul(A, B): {torch.matmul(A, B)}")
print(f"torch.matmul(A, B).size(): {torch.matmul(A, B).size()}")
此时就是我们常说的点积(dot product)
,返回标量。注意,这里是维度为 1,而不是元素总数。
A = torch.randint(0, 5, size=(2,))
B = torch.randint(0, 5, size=(2,))
print_info(A, B)
>> A: tensor([4, 4])
>> B: tensor([3, 0])
>> A 的维度: 1, B 的维度: 1
>> A 的元素总数: 2, B 的元素总数: 2
>> torch.matmul(A, B) = 12
>> torch.matmul(A, B).size() = torch.Size([])
返回矩阵乘积的结果。
A = torch.randint(0, 5, size=(2, 1))
B = torch.randint(0, 5, size=(1, 2))
print_info(A, B)
>> A: tensor([[3],
>> [4]])
>> B: tensor([[2, 3]])
>> A 的维度: 2, B 的维度: 2
>> A 的元素总数: 2, B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[ 6, 9],
>> [ 8, 12]])
>> torch.matmul(A, B).size() = torch.Size([2, 2])
按照广播机制(boardcasting)进行处理,即:从 size 的尾部开始一一比对,如果维度不够,则扩展一维,令初始值为 1 再进行计算。计算完之后移除扩展的维度,用下面的例子来说就是扩展成 (1, 2) 后,(1, 2) * (2, 2) => (1, 2) => (2, )
A = torch.randint(0, 5, size=(2, ))
B = torch.randint(0, 5, size=(2, 2))
print_info(A, B)
>> A: tensor([2, 3])
>> B: tensor([[1, 1],
>> [1, 4]])
>> A 的维度: 1, B 的维度: 2
>> A 的元素总数: 2, B 的元素总数: 4
>> torch.matmul(A, B) = tensor([ 5, 14])
>> torch.matmul(A, B).size() = torch.Size([2])
返回矩阵与向量的乘积。
# 这里使用上一次的矩阵和向量,方便对照
print_info(B, A)
>> A: tensor([[1, 1],
>> [1, 4]])
>> B: tensor([2, 3])
>> A 的维度: 2, B 的维度: 1
>> A 的元素总数: 4, B 的元素总数: 2
>> torch.matmul(A, B) = tensor([ 5, 14])
>> torch.matmul(A, B).size() = torch.Size([2])
以 input_d > 2 为例,维度不匹配就通过广播机制扩展,最后结果上删除掉扩展的维度。
个人理解:对于 dim >= 2 的 tensor 来说最后两维被看作矩阵的行和列,其余(如果存在)被看作 batch。
对于非矩阵(non-matrix)维度也是进行广播处理的,以 A.size() = (j, 1, m, n) 和 B.size() =(k, n, m) 为例,j x 1 和 k 是非矩阵维度,也就是 batch 维度,torch.matmul(A, B).size() = (j, k, m, m)。
矩阵部分:(1, 2) * (2, 1)
A = torch.randint(0, 5, size=(2, 1, 2))
B = torch.randint(0, 5, size=(2, 1))
print_info(A, B)
>> A: tensor([[[3, 1]],
>>
>> [[1, 3]]])
>> B: tensor([[4],
>> [3]])
>> A 的维度: 3, B 的维度: 2
>> A 的元素总数: 4, B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[[15]],
>>
>> [[13]]])
这里可以看成单拎出 A 的最后 2 维与 B 做 input_d = 2 和 other_d = 1 的乘法:(1, 2) * (2, ),具体细节可以回看上面对应的部分。
A = torch.randint(0, 5, size=(2, 1, 2))
B = torch.randint(0, 5, size=(2, ))
print_info(A, B)
>> A: tensor([[[1, 4]],
>>
>> [[1, 4]]])
>> B: tensor([4, 1])
>> A 的维度: 3, B 的维度: 1
>> A 的元素总数: 4, B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[8],
>> [8]])
>> torch.matmul(A, B).size() = torch.Size([2, 1])
广播部分:(2, 1, *, *) => (2, 2, *, *)。矩阵部分:(2, 1) * (1, 2)
A = torch.randint(0, 5, size=(2, 1, 2, 1))
B = torch.randint(0, 5, size=(2, 1, 2))
print_info(A, B)
>> A: tensor([[[[4],
>> [4]]],
>>
>>
>> [[[4],
>> [0]]]])
>> B: tensor([[[1, 2]],
>>
>> [[3, 0]]])
>> A 的维度: 4, B 的维度: 3
>> A 的元素总数: 4, B 的元素总数: 4
>> torch.matmul(A, B) = tensor([[[[ 4, 8],
>> [ 4, 8]],
>>
>> [[12, 0],
>> [12, 0]]],
>>
>>
>> [[[ 4, 8],
>> [ 0, 0]],
>>
>> [[12, 0],
>> [ 0, 0]]]])
>> torch.matmul(A, B).size() = torch.Size([2, 2, 2, 2])
在往下翻之前不妨思考一下 torch.matmul(B, A).size() 等于多少。
print_info(B, A)
>> A: tensor([[[1, 2]],
>>
>> [[3, 0]]])
>> B: tensor([[[[4],
>> [4]]],
>>
>>
>> [[[4],
>> [0]]]])
>> A 的维度: 3, B 的维度: 4
>> A 的元素总数: 4, B 的元素总数: 4
>> torch.matmul(A, B) = tensor([[[[12]],
>>
>> [[12]]],
>>
>>
>> [[[ 4]],
>>
>> [[12]]]])
>> torch.matmul(A, B).size() = torch.Size([2, 2, 1, 1])
Broadcasting