图解PyTorch中的torch.gather函数

1、官方示例代码

import torch
t = torch.tensor([[1, 2], [3, 4]])
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))

2、运行结果

图解PyTorch中的torch.gather函数_第1张图片

3、图解

图解PyTorch中的torch.gather函数_第2张图片

  4、验证

4.1、第一行的第一个元素和第二行的第二个元素

图解PyTorch中的torch.gather函数_第3张图片

  4.2、3个第一行的第一个元素和3个第二行的第个元素

 参考:

torch.gather — PyTorch 2.0 documentationhttps://pytorch.org/docs/stable/generated/torch.gather.html#torch.gather

你可能感兴趣的:(Python,pytorch,深度学习,python)