PyTorch搜索Tensor指定维度的前K大个(K小个)元素--------(torch.topk)命令参数详解及举例

torch.topk

语法

torch.topk(input, k, dim=None, largest=True, sorted=True, *, out = None)

作用

返回输入tensorinput中,在给定的维度dimk个最大的元素。

如果dim没有给定,那么选择输入input的最后一维。

如果largest = False,那么返回k个最小的元素。

返回一个namedtuple类型的元组(values, indices),其中indices是指元素在原数组中的索引。

sorted = True, 则返回的k个元素是有序的。

Parameters

  • input (Tensor) – the input tensor
    输入的张量

  • k (int) – the k in “top-k”
    返回的k的值

  • dim(int, optional) – the dimension to sort along
    指定的排序的维度 ,如果dim没有给定,那么选择输入input的最后一维。 dim若为-1,文档未说明,但是根据实操效果,应该也是对最后一维进行search。
    如shape为Batch_size x p x q,返回结果为Batch_size x p x k

  • largest(bool, optional) – controls whether to return largest or smallest elements
    True返回最大值,False返回最小值。

  • sorted(bool, optional) – controls whether to return the elements in sorted order
    控制返回的元素是否排序。

例子

>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))

你可能感兴趣的:(PyTorch,笔记,python,pytorch)