torch.sort()
方法原型:
torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, LongTensor)
返回值:
A tuple of (sorted_tensor, sorted_indices) is returned, where the sorted_indices are the indices of the elements in the original input tensor.
import torch
x = torch.randn(3,4)
x #初始值,始终不变
tensor([[-0.9950, -0.6175, -0.1253, 1.3536],
[ 0.1208, -0.4237, -1.1313, 0.9022],
[-1.1995, -0.0699, -0.4396, 0.8043]])
sorted, indices = torch.sort(x) #按行从小到大排序
sorted
tensor([[-0.9950, -0.6175, -0.1253, 1.3536],
[-1.1313, -0.4237, 0.1208, 0.9022],
[-1.1995, -0.4396, -0.0699, 0.8043]])
indices
tensor([[0, 1, 2, 3],
[2, 1, 0, 3],
[0, 2, 1, 3]])
sorted, indices = torch.sort(x, descending=True) #按行从大到小排序 (即反序)
sorted
tensor([[ 1.3536, -0.1253, -0.6175, -0.9950],
[ 0.9022, 0.1208, -0.4237, -1.1313],
[ 0.8043, -0.0699, -0.4396, -1.1995]])
indices
tensor([[3, 2, 1, 0],
[3, 0, 1, 2],
[3, 1, 2, 0]])
sorted, indices = torch.sort(x, dim=0) #按列从小到大排序
sorted
tensor([[-1.1995, -0.6175, -1.1313, 0.8043],
[-0.9950, -0.4237, -0.4396, 0.9022],
[ 0.1208, -0.0699, -0.1253, 1.3536]])
indices
tensor([[2, 0, 1, 2],
[0, 1, 2, 1],
[1, 2, 0, 0]])
sorted, indices = torch.sort(x, dim=0, descending=True) #按列从大到小排序
sorted
tensor([[ 0.1208, -0.0699, -0.1253, 1.3536],
[-0.9950, -0.4237, -0.4396, 0.9022],
[-1.1995, -0.6175, -1.1313, 0.8043]])
indices
tensor([[1, 2, 0, 0],
[0, 1, 2, 1],
[2, 0, 1, 2]])
官方文档:https://pytorch.org/docs/stable/torch.html