torch.matmul() 详解

最近在准备做 HW04,在读 transformer 的源码的时候发现 attention score 的 torch.matmul() 的奇妙设置,故有此篇文章进行分享。

前言碎碎念:

一开始我以为 torch.matmul 所做的工作就是简单的矩阵相乘,即:假设我们有两个矩阵 AB,它们的 size 分别为 (m, n)(n, p),那么 A x B 的 size 为 (m, p)。然后我看了眼官方文档的例子:

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
>> torch.Size([10, 3, 5])

大大的问号冒了出来 : ),这也能乘?

文章的代码文件:notebook 代码

文章目录

  • 前期工作
  • input_d = other_d = 1(两个 Tensor 皆为 1 维)
  • input_d = other_d = 2 (两个 Tensor 皆为 2 维)
  • input_d = 1, other_d = 2
  • input_d = 2, other_d = 1
  • input_d > 2 or other_d > 2
    • input_d > 2 and other_d = 2
    • input_d > 2 and other_d = 1
    • input_d > 2 and other_d >2 (多维 Tensor)
  • 拓展阅读

下面结合官方文档提供一些例子给大家理解。

torch.matmul(input, other, ***, out=None) → Tensor

两个张量的矩阵乘积,具体行为取决于张量的维度,如下所示。

这里为了描述方便,用 input_dother_d 分别指代 input.dim()other.dim(),使用 torch.randint() 替代 torch.randn() 方便印证。

前期工作

import torch

# 固定 torch 的随机数种子,以便重现结果
torch.manual_seed(0)

# 打印信息
def print_info(A, B):
    print(f"A: {A}\nB: {B}")
    print(f"A 的维度: {A.dim()},\t B 的维度: {B.dim()}")
    print(f"A 的元素总数: {A.numel()},\t B 的元素总数: {B.numel()}")
    print(f"torch.matmul(A, B): {torch.matmul(A, B)}")
    print(f"torch.matmul(A, B).size(): {torch.matmul(A, B).size()}")
    

input_d = other_d = 1(两个 Tensor 皆为 1 维)

此时就是我们常说的点积(dot product),返回标量。注意,这里是维度为 1,而不是元素总数。

A = torch.randint(0, 5, size=(2,))
B = torch.randint(0, 5, size=(2,))

print_info(A, B)
>> A: tensor([4, 4])
>> B: tensor([3, 0])
>> A 的维度: 1,	 B 的维度: 1
>> A 的元素总数: 2,	 B 的元素总数: 2
>> torch.matmul(A, B) = 12
>> torch.matmul(A, B).size() = torch.Size([])

input_d = other_d = 2 (两个 Tensor 皆为 2 维)

返回矩阵乘积的结果。

A = torch.randint(0, 5, size=(2, 1))
B = torch.randint(0, 5, size=(1, 2))

print_info(A, B)
>> A: tensor([[3],
>>         [4]])
>> B: tensor([[2, 3]])
>> A 的维度: 2,	 B 的维度: 2
>> A 的元素总数: 2,	 B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[ 6,  9],
>>         [ 8, 12]])
>> torch.matmul(A, B).size() = torch.Size([2, 2])

input_d = 1, other_d = 2

按照广播机制(boardcasting)进行处理,即:从 size 的尾部开始一一比对,如果维度不够,则扩展一维,令初始值为 1 再进行计算。计算完之后移除扩展的维度,用下面的例子来说就是扩展成 (1, 2) 后,(1, 2) * (2, 2) => (1, 2) => (2, )

A = torch.randint(0, 5, size=(2, ))
B = torch.randint(0, 5, size=(2, 2))

print_info(A, B)
>> A: tensor([2, 3])
>> B: tensor([[1, 1],
>>         [1, 4]])
>> A 的维度: 1,	 B 的维度: 2
>> A 的元素总数: 2,	 B 的元素总数: 4
>> torch.matmul(A, B) = tensor([ 5, 14])
>> torch.matmul(A, B).size() = torch.Size([2])

input_d = 2, other_d = 1

返回矩阵与向量的乘积。

# 这里使用上一次的矩阵和向量,方便对照
print_info(B, A)
>> A: tensor([[1, 1],
>>         [1, 4]])
>> B: tensor([2, 3])
>> A 的维度: 2,	 B 的维度: 1
>> A 的元素总数: 4,	 B 的元素总数: 2
>> torch.matmul(A, B) = tensor([ 5, 14])
>> torch.matmul(A, B).size() = torch.Size([2])

input_d > 2 or other_d > 2

以 input_d > 2 为例,维度不匹配就通过广播机制扩展,最后结果上删除掉扩展的维度。

个人理解:对于 dim >= 2 的 tensor 来说最后两维被看作矩阵的行和列,其余(如果存在)被看作 batch。

对于非矩阵(non-matrix)维度也是进行广播处理的,以 A.size() = (j, 1, m, n) 和 B.size() =(k, n, m) 为例,j x 1 和 k 是非矩阵维度,也就是 batch 维度,torch.matmul(A, B).size() = (j, k, m, m)。

input_d > 2 and other_d = 2

矩阵部分:(1, 2) * (2, 1)

A = torch.randint(0, 5, size=(2, 1, 2))
B = torch.randint(0, 5, size=(2, 1))

print_info(A, B)
>> A: tensor([[[3, 1]],
>> 
>>         [[1, 3]]])
>> B: tensor([[4],
>>         [3]])
>> A 的维度: 3,	 B 的维度: 2
>> A 的元素总数: 4,	 B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[[15]],
>> 
>>         [[13]]])

input_d > 2 and other_d = 1

这里可以看成单拎出 A 的最后 2 维与 B 做 input_d = 2 和 other_d = 1 的乘法:(1, 2) * (2, ),具体细节可以回看上面对应的部分。

A = torch.randint(0, 5, size=(2, 1, 2))
B = torch.randint(0, 5, size=(2, ))

print_info(A, B)
>> A: tensor([[[1, 4]],
>> 
>>         [[1, 4]]])
>> B: tensor([4, 1])
>> A 的维度: 3,	 B 的维度: 1
>> A 的元素总数: 4,	 B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[8],
>>         [8]])
>> torch.matmul(A, B).size() = torch.Size([2, 1])

input_d > 2 and other_d >2 (多维 Tensor)

广播部分:(2, 1, *, *) => (2, 2, *, *)。矩阵部分:(2, 1) * (1, 2)

A = torch.randint(0, 5, size=(2, 1, 2, 1))
B = torch.randint(0, 5, size=(2, 1, 2))

print_info(A, B)
>> A: tensor([[[[4],
>>           [4]]],
>> 
>> 
>>         [[[4],
>>           [0]]]])
>> B: tensor([[[1, 2]],
>> 
>>         [[3, 0]]])
>> A 的维度: 4,	 B 的维度: 3
>> A 的元素总数: 4,	 B 的元素总数: 4
>> torch.matmul(A, B) = tensor([[[[ 4,  8],
>>           [ 4,  8]],
>> 
>>          [[12,  0],
>>           [12,  0]]],
>> 
>> 
>>         [[[ 4,  8],
>>           [ 0,  0]],
>> 
>>          [[12,  0],
>>           [ 0,  0]]]])
>> torch.matmul(A, B).size() = torch.Size([2, 2, 2, 2])

在往下翻之前不妨思考一下 torch.matmul(B, A).size() 等于多少。

print_info(B, A)
>> A: tensor([[[1, 2]],
>> 
>>         [[3, 0]]])
>> B: tensor([[[[4],
>>           [4]]],
>> 
>> 
>>         [[[4],
>>           [0]]]])
>> A 的维度: 3,	 B 的维度: 4
>> A 的元素总数: 4,	 B 的元素总数: 4
>> torch.matmul(A, B) = tensor([[[[12]],
>> 
>>          [[12]]],
>> 
>> 
>>         [[[ 4]],
>> 
>>          [[12]]]])
>> torch.matmul(A, B).size() = torch.Size([2, 2, 1, 1])

拓展阅读

Broadcasting

你可能感兴趣的:(经验及避坑分享,深度学习,pytorch,人工智能)