[Pytorch]Broadcasting广播机制

文章目录

  • Broadcasting广播机制
    • Broadcastable
    • Broadcasting

Broadcasting广播机制

Broadcasting机制用于在不同维度的张量进行运算时进行维度的自动增加与扩展,Broadcasting机制使用的前提是两个参与运算的张量是可broadcastable的。

Broadcastable

怎样的两个向量是Broadcastable的,也就是可使用Broadcasting机制的?
规定右边的维度为小维度,即shape(2,32,13,13)中右边的13为最小的维度
张量x和张量y如果符合以下条件就是broadcastable的:

  • 将x的shape与y的shape按照最右边的最小维度对齐,从小维度开始向左看:对于每一个二者都有的维度,如果x在该维度的大小为1或y在该维度的大小为1或二者都是1,或者二者在该维度的大小都不是1但是二者维度大小相等。

例如:

example 1: x,y,z中任意两个张量是broadcastable的
x.shape = [2, 3, 4]
y.shape =     [3, 4]
z.shape =         [4]

对于x和z,按照小维度对齐后,在最小的维度上x和z的维度大小都是4,符合规则。
对于x和y,按照小维度对齐后,在最小的维度上x和y的维度大小都是4,在次小维度上x和y的维度大小都是3,也符合规则。

example 2:x,y不是broadcastable的
x.shape = [4, 5, 13, 13]
y.shape =      [4, 13, 1]

对于x和y,按照小维度对齐后,最小维度上y的维度大小为1,该维度符合,次小维度上x的维度大小和y的维度大小相等,该维度符合,但是在第三小维度上x的维度大小为5,而y的维度大小为4,二者非1且不相等,所以不符合broadcasting的规则。

Broadcasting

当两个broadcastable的符合broadcasting规则的张量在进行相加时,会自动使用维度增加和扩展运算使两个张量拥有相同的维度并计算,维度增加和扩展的规则如下:

  • 将x的shape与y的shape按照最右边的最小维度对齐,从小维度开始向左看:对于每一个二者都有的维度,将该维度扩展到二者的最大值;当其中一个张量的维度已经不存在,而另一个张量维度更多时,则给维度少的张量插入新的维度,使二者维度数量保持一致,同时也要将该维度扩展到二者中在该维度大小的最大值。

例如:

  • tensor a = [1, 2, 3]
  • tensor b = [10,
                      20,
                      30,
                      40]

将维度对齐后:

  • tensor b.shape = [4, 1]
  • tensor a.shape =     [3]

张量a和b扩展的详细过程:

  • 最小维度上,张量a的维度大小为3是最大值,张量b在该维度上扩展至大小3: [4,1] => [4,3]
  • 在次小维度上,张量a在该维度上不存在,则插入一个新的维度给张量a:[3] => [1,3]
  • 插入新的维度后在次小维度上,张量b在该维度上的维度大小为4是最大值,张量a在该维度上进行扩展至大小4:[1,3] => [4,3]
  • 维度扩展的方式就是在该维度上将原数据进行复制(逻辑复制,不开辟实际内存)
a = torch.tensor([1, 2, 3])
b = torch.tensor([
    [10],
    [20],
    [30],
    [40],
])
print(f"{a.shape} + {b.shape} = {(a+b).shape}")
print(a+b)

[Pytorch]Broadcasting广播机制_第1张图片

另一个示例:

x = torch.randint(0, 2, [3, 1])
y = torch.randint(0, 2, [4])
print(f"{x.shape} + {y.shape} = {(x+y).shape}")

在这里插入图片描述

你可能感兴趣的:(Pytorch,pytorch,python,人工智能)