tensor 切片操作(用tensor去切片)

当tensor作为切片索引时需要注意tensor的类型(tensors used as indices must be long, byte or bool tensors)

import torch
import numpy as np

y = np.array([1,3])
a = torch.tensor(y, dtype=torch.long)
x = torch.rand(5,3)
print(x)
print(y)
print(x[a,0])   #数字0表示取第几列,a(tensor)表示取第几个元素

out :tensor([[0.3472, 0.7059, 0.2417],
        [0.0198, 0.1794, 0.9050],
        [0.7338, 0.8713, 0.9565],
        [0.6389, 0.4281, 0.0716],
        [0.8669, 0.9940, 0.5680]])
  [1  2]
tensor([0.0198, 0.7338])

类似的numpy也可以这么切片,获取想要的数据

你可能感兴趣的:(python,pytorch)