『PyTorch』张量和函数之gather()函数

文章目录

  • PyTorch中的选择函数
    • gather()函数
  • 参考文献

PyTorch中的选择函数

gather()函数

import torch
a = torch.arange(1, 16).reshape(5, 3)
"""
result:
a = [[1, 2, 3],
      [4, 5, 6],
      [7, 8, 9],
      [10, 11, 12],
      [13, 14, 15]]
"""

# 定义两个index
b = torch.tensor([[0, 1, 2], [2, 3, 4], [0, 2, 4]])
c = torch.tensor([[1, 2, 0, 2, 1], [1, 2, 1, 0, 0]])

# axis=0
output1 = a.gather(0, b)
"""
result:
[[1, 5, 9],
[7, 11, 15],
[1, 8, 15]]
"""

# axis=1
output2 = a.gather(1, c)
"""
result:
[[2, 3, 1, 3, 2],
[5, 6, 5, 4, 4]]
"""

『PyTorch』张量和函数之gather()函数_第1张图片

参考文献

1、理解pytorch几个高级选择函数(如gather)
2、图解PyTorch的torch.gather函数

你可能感兴趣的:(pytorch,python)