gather()

bPytorch系列(1):torch.gather()_c-minus的博客-CSDN博客_torch。gather

a.gather(dim,b)  --以dim和b为索引,从a中提取数据组成新的tensor,shape与b一致。

其中a为tensor类型,b为longtensor类型。

dim = 0 ,按列提取;dim = 1 ,按行提取。

总结:确定好按行or列提取之后,b中每行数据代表从a中提取每行or列的第几个数。(若b第一行为[1,2],则表示提取a中第一行or列的第2,3个数。以此类推)

import torch
a = [
    [2, 3, 4, 5, 0, 0],
    [1, 4, 3, 0, 0, 0],
    [4, 2, 2, 5, 7, 0],
    [1, 0, 0, 0, 0, 0]
]
a = torch.tensor(a)
b = torch.LongTensor([[4],[3],[5],[1]])
#b之所以减1,是因为序列维度是从0开始计算的
out = torch.gather(a, 1, b-1)
out

你可能感兴趣的:(目标检测,目标检测)