torch.sort

torch.sort()
方法原型:
torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)
返回值:
A tuple of (sorted_tensor, sorted_indices) is returned, where the sorted_indices are the indices of the elements in the original input tensor.

参数

  • input (Tensor) – the input tensor
    形式上与 numpy.narray 类似
  • dim (int, optional) – the dimension to sort along
    维度,对于二维数据:dim=0 按列排序,dim=1 按行排序,默认 dim=1
  • descending (bool, optional) – controls the sorting order (ascending or descending)
    降序,descending=True 从大到小排序,descending=False 从小到大排序,默认 descending=Flase

实例

import torch
x = torch.randn(3,4)
x  #初始值,始终不变
tensor([[-0.9950, -0.6175, -0.1253,  1.3536],
        [ 0.1208, -0.4237, -1.1313,  0.9022],
        [-1.1995, -0.0699, -0.4396,  0.8043]])
sorted, indices = torch.sort(x)  #按行从小到大排序
sorted
tensor([[-0.9950, -0.6175, -0.1253,  1.3536],
        [-1.1313, -0.4237,  0.1208,  0.9022],
        [-1.1995, -0.4396, -0.0699,  0.8043]])
indices
tensor([[0, 1, 2, 3],
        [2, 1, 0, 3],
        [0, 2, 1, 3]])
sorted, indices = torch.sort(x, descending=True)  #按行从大到小排序 (即反序)
sorted
tensor([[ 1.3536, -0.1253, -0.6175, -0.9950],
        [ 0.9022,  0.1208, -0.4237, -1.1313],
        [ 0.8043, -0.0699, -0.4396, -1.1995]])
indices
tensor([[3, 2, 1, 0],
        [3, 0, 1, 2],
        [3, 1, 2, 0]])
sorted, indices = torch.sort(x, dim=0)  #按列从小到大排序
sorted
tensor([[-1.1995, -0.6175, -1.1313,  0.8043],
        [-0.9950, -0.4237, -0.4396,  0.9022],
        [ 0.1208, -0.0699, -0.1253,  1.3536]])
indices
tensor([[2, 0, 1, 2],
        [0, 1, 2, 1],
        [1, 2, 0, 0]])
sorted, indices = torch.sort(x, dim=0, descending=True)  #按列从大到小排序
sorted
tensor([[ 0.1208, -0.0699, -0.1253,  1.3536],
        [-0.9950, -0.4237, -0.4396,  0.9022],
        [-1.1995, -0.6175, -1.1313,  0.8043]])
indices
tensor([[1, 2, 0, 0],
        [0, 1, 2, 1],
        [2, 0, 1, 2]])

官方文档:https://pytorch.org/docs/stable/torch.html

你可能感兴趣的:(pytorch)