This article was original written by XRBLS, welcome re-post, first come with https://jinfagang.github.io . but please keep this copyright info, thanks, any question could be asked via wechat:
jintianiloveu
torch.gather
只是一个引子,别看它简单,但能引出很多问题。我们先来看看,它是如何工作的。假如我们有一个矩阵:
[[34, 4, 6],
[45, 6, 7]]
我们想要对它每一个位置的点进行重新排列应该怎么做呢?比如我要得到这么一个矩阵:
[[4, 6, 34],
[6, 7, 45]]
可以看到,我把每一行的(此时的axis=1)位置进行了变换。具体来说,用torch.gather
可以做这个事情:
r = torch.gather(a, 1, torch.tensor([[1, 2, 0], [1, 2, 0]]))
说白了,就是用一个矩阵来对它进行重排。那么到底在什么场合我们会用到这个函数呢?
其实一个很明显的作用就是在分类问题中,通过gather方法可以从一个矩阵里面挑选出最大值来完成分类任务。
之前有遇到一个onnx2trt的问题,但是本质上并不是由于它造成的,跟gather没有太大的关系。