【Pytorch】tensor可以按行、列筛选的大利器——torch.index_select()

因为自己遇到了将 Tensor

[1000,2] 截断为 [200,2]

的需求,故在网络上寻找对策。
先是看到了 tensor.gather() , 确实是不错的函数,但为了此需求稍显复杂。
终于发现了torch.index_select() ,好用到爆炸。

献上官方文档的例子:

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)   # 0指的是列,按indices[0,2]取得就是x的第一、第三行。
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)  # 1指的是行,按indices[0,2]取得就是x的第一、第三列。
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])

非常好理解是吧。

然后我就用这个函数开心的解决了我的问题:

x = torch.randn(1000,2)
ind = []
for i in range(200):   # 取tensor的前200行
    ind.append(i)
indices = torch.LongTensor(ind).to(device)
out_rnn = torch.index_select(x, 0, indices).to(device)

你可能感兴趣的:(pytorch,神经网络,深度学习,人工智能,python)