Pytorch:.gather(1,)和.gather(0,)的区别

Pytorch中gather(1,)和.gather(0,)的区别?

在 PyTorch 中,.gather(dim, index) 函数用于根据给定的索引在指定的维度上获取张量的元素。其中,dim 表示要进行索引的维度,index 是包含索引值的一个张量。

下面以一个简单的例子来解释 .gather(1,) 和 .gather(0,) 的区别:

import torch

# 构造一个 3x4 的 Tensor,每个元素的值为对应位置的行列索引之和
x = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5]])

# 沿着行的方向使用 gather,将第 i 行的第 idx[i] 个元素取出来
idx = torch.tensor([1, 2, 0])
result = x.gather(1, idx.view(-1, 1))
print(result)  # prints: tensor([[1], [3], [2]])

# 沿着列的方向使用 gather,将第 j 列的第 idx[j] 个元素取出来
idx = torch.tensor([1, 2, 0, 3])
result = x.gather(0, idx.view(1, -1))
print(result)  # prints: tensor([[1, 3, 2, 3]])

在上面的例子中,我们首先构造了一个 3x4 的 Tensor x,每个元素的值为对应位置的行列索引之和。然后,我们分别使用 .gather(1, index) 和 .gather(0, index) 对其进行索引。

.gather(1, index) 表示沿着行的方向使用 gather,idx 张量的每个元素表示对应行的要取出的元素的列索引,dim=1 表示对列进行操作,那么返回的结果是一个列向量,它的第 i i i 个元素是第 i i i 行的第 i d x [ i ] idx[i] idx[i] 个元素。例如,当 idx=[1, 2, 0] 时,结果应该是 [1, 3, 2]。
.gather(0, index) 表示沿着列的方向使用 gather,idx 张量的每个元素表示对应列的要取出的元素的行索引,dim=0 表示对行进行操作,那么返回的结果是一个行向量,它的第 j j j 个元素是第 i d x [ j ] idx[j] idx[j] 列的第 j j j 个元素。例如,当 idx=[1, 2, 0, 3] 时,结果应该是 [1, 3, 2, 3]。
总之,.gather(dim, index) 函数可以根据指定的索引在指定的维度上对张量进行索引,但需要注意 dim 和 index 的值对于结果的影响。

你可能感兴趣的:(Pytorch学习手册,pytorch,人工智能,python)