Pytorch 实现tf.gather()

1. 实现tf.gather

在pytorch中,实现 tf.gather 很简单,只需要使用 select。
select(dim, index) → Tensor

比如,

import numpy as np
a = np.array([[1],[2],[3],[4],[5]])
b = torch.from_numpy(a)
indices = [ 1, 2, 0]
b[indices]

Output:

tensor([[2],
        [3],
        [1]])

参考:

  1. How to implement an equivalent of tf.gather in pytorch

你可能感兴趣的:(学习Pytorch)