【Torch API】pytorch 中index_add()函数详解

import torch

# create a tensor of zeros with shape (3, 4)
t = torch.zeros((3, 4))

# create an index tensor of shape (2,)
index = torch.tensor([0, 2])

# create a tensor to add with shape (2, 4)
src = torch.ones((2, 4))

# add the src tensor to t at the indices specified in the index tensor
t.index_add(0, index, src)

print(t)

输出结果为:

tensor([[1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [1., 1., 1., 1.]])

解释:
上述代码,使用了 index_add() 方法将 src 张量添加到 t 张量中。具体而言,下面这行代码:

t.index_add(0, index, src)

完成了相加的操作。下面是每个参数的含义:

0:指定执行索引操作的维度。在本例中,我们要将 src 张量添加到由 index 张量指定的 t 张量的行上,因此设置 dim=0。

index:指定要将 src 张量添加到哪些位置的索引。在本例中,我们要将 src 张量添加到 t 的第 0 行和第 2 行,因此设置 index=torch.tensor([0, 2])。

src:要添加到 t 的张量。在本例中,src 是一个全部为 1 的张量,形状为 (2, 4)。

index_add() 方法的工作原理是将 src 的每一行加到由 index 张量指定的 t 张量的相应行上。在本例中,指定了 index=torch.tensor([0, 2]),因此 src 的第一行(全为 1)会被加到 t 的第一行,第二行(同样全为 1)会被加到 t 的第三行。结果是一个张量,其中第一行和第三行为 1,其它地方为 0。

你可能感兴趣的:(基础知识,pytorch)