torch.topk与torch.sort用法

torch.topk与torch.sort的用法类似,topk是在sort的基础上取前k个值。topk默认是降序(从大到小);而sort是默认是升序(从小到大)。具体用法如下:

1、torch.topk用法

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

沿给定dim维度返回输入张量inputk 个最大值。 如果不指定dim,则默认为input的最后一维。 如果为largestFalse ,则返回最小的 k 个值。

返回一个元组 (values,indices),其中indices是原始输入张量input中测元素下标。 如果设定布尔值sorted 为_True_,将会确保返回的 k 个值被排序。

参数:

  • input (Tensor) – 输入张量

  • k (int) – “top-k”中的k

  • dim (int, optional) – 排序的维

  • largest (bool, optional) – 布尔值,控制返回最大或最小值

  • sorted (bool, optional) – 布尔值,控制返回值是否排序

  • out (tuple, optional) – 可选输出张量 (Tensor, LongTensor) output buffers

    >>> import torch
    >>> x = torch.arange(1, 6) 
    >>> x
    tensor([1, 2, 3, 4, 5])
    >>> torch.topk(x,3)
    torch.return_types.topk(
    values=tensor([5, 4, 3]),
    indices=tensor([4, 3, 2]))
    >>> values, indices = torch.topk(x, 3)
    >>> values
    tensor([5, 4, 3])
    >>> indices
    tensor([4, 3, 2])
    >>> torch.topk(x, 1)
    torch.return_types.topk(
    values=tensor([5]),
    indices=tensor([4]))
    >>> torch.topk(x, 1, largest=False) # 最小值,从小到大
    torch.return_types.topk(
    values=tensor([1]),
    indices=tensor([0]))
    

2、torch.sort用法

torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)

对输入张量input沿着指定维按升序排序。如果不给定dim,则默认为输入的最后一维。如果指定参数descendingTrue,则按降序排序

返回元组 (sorted_tensor, sorted_indices) , sorted_indices 为原始输入中的下标。

参数:

  • input (Tensor) – 要对比的张量
  • dim (int, optional) – 沿着此维排序
  • descending (bool, optional) – 布尔值,控制升降排序
  • out (tuple, optional) – 输出张量。必须为ByteTensor或者与第一个参数tensor相同类型。
>>> import torch
>>> x = torch.arange(1, 6) 
>>> torch.sort(x)
torch.return_types.sort(
values=tensor([1, 2, 3, 4, 5]),
indices=tensor([0, 1, 2, 3, 4]))
>>> values, indices = torch.sort(x)
>>> values
tensor([1, 2, 3, 4, 5])
>>> indices
tensor([0, 1, 2, 3, 4])

参考:

torch.topk & sort官方文档

你可能感兴趣的:(pytorch,pytorch)