PyTorch张量(tensor)之间的广播

最近在阅读一些源码时总是会被它们中的一些涉及张量操作的代码迷惑,尤其是不同shape之间的张量之间的广播。所以今天就写一篇关于张量(tensor)的学习笔记来记录一下。

广播(broadcasting)是一种在不同形状的张量之间执行按元素运算的机制。它允许 PyTorch 在不同形状的张量之间执行按元素运算,而无需显式地调整它们的形状。

当我们在不同形状的张量之间执行按元素运算时,PyTorch 会自动将它们扩展到相同的形状。例如,如果我们有一个形状为[5,1]的张量x和一个形状为[1,1]的张量y,那么当我们执行 x+y时,PyTorch 会自动将y扩展到与x相同的形状[5,1],然后执行按元素加法。

import torch
x = torch.randn(5, 1)
y = torch.randn(1, 1)
z = x + y
print(z.shape) # torch.Size([5, 1])
#这句代码可以查看y广播到能够于x进行张量运算后的结果
print("y expanded:", y.expand_as(x))

还有一种例子是如果我们有一个形状为 [5, 1] 的张量 x 和一个形状为 [1, 4] 的张量 y,那么当我们执行 x + y 时,PyTorch 会自动将它们扩展到相同的形状 [5, 4],然后执行按元素加法。这样也可以进行广播,但是无法使用y.expand_as(x)。

import torch
x = torch.randn(5, 1)
y = torch.randn(1, 4)
z = x + y
print(z.shape) # torch.Size([5, 4])

也就是说当两个张量的维度相同时,如果对应的维度都不相同,但其中只要有一个张量在维度的维度数是1就可以进行广播。

当两个张量的维度数不相同时,比如1个张量是3维的一个张量是2维的又该怎么做呢?下面是一个例子。

import torch
x = torch.randn(2, 3, 2)
y = torch.randn(1, 3)
z = x + y # RuntimeError: The expanded size of the tensor (2) must match the existing size (3) at non-singleton dimension 2. Target sizes: [2, 3, 2]. Tensor sizes: [1, 3]

这样就不可以进行广播操作,因为PyTorch在比较两个张量是否兼容的时候会从最后一个维度开始进行判断,给出的例子中,第一个张量的最后一个维度大小为 2,第二个张量的最后一个维度大小为 3,所以它们在这个维度上的大小不相同,且都不等于 1。因此,这两个张量的形状是不兼容的,无法进行广播。

如果把y的第二维改成2或1那么都可以进行广播操作,相同的y的第一维也只能是1或者3。

import torch
x = torch.randn(2, 3, 2)
y = torch.randn(1, 2)
z = x + y

也就是说在不同维度的张量广播中我们需要对维度较小的那个张量从最后向前来审查它是否能与维度较大的对应。

当我们使用bool类型的张量去筛选其它张量时,假如y是bool类型它的shape是(2,3),要筛选的张量的shape为(2,3,2),这时y是可以对进行筛选的。

import torch
x = torch.randn(2, 3, 2)
y = torch.randint(0, 2,(2,3),dtype=torch.bool)

z = x[y]
print(z.shape)#(N,2)其中N代表z中元素表达为True的元素个数

你可能感兴趣的:(pytorch,人工智能,python)