Pytorch中torch.sort()和torch.argsort()函数解析

一. torch.sort()函数解析

1. 官网链接

torch.sort(),如下图所示:

2. torch.sort()函数解析

torch.sort(input, dim=- 1, descending=False, stable=False, *, out=None)

输入input,在dim维进行排序,默认是dim=-1对最后一维进行排序,descending表示是否按降序排,默认为False,输出排序后的值以及对应值在原输入imput中的下标

3. 代码举例

3.1 dim = -1 表示对每行中的元素进行升序排序,descending=False表示升序排序

x = torch.randn(3, 4)
sorted, indices = torch.sort(x)
x,sorted,indices

输出结果如下:
(tensor([[-1.3864,  0.5811, -0.1056, -0.3237],
         [-0.2136, -1.4806,  0.4986,  0.9382],
         [-0.2820,  0.1171, -0.3983, -0.8061]]),
 tensor([[-1.3864, -0.3237, -0.1056,  0.5811],
         [-1.4806, -0.2136,  0.4986,  0.9382],
         [-0.8061, -0.3983, -0.2820,  0.1171]]),
 tensor([[0, 3, 2, 1],
         [1, 0, 2, 3],
         [3, 2, 0, 1]]))

3.2 dim = 0 表示对每列中的元素进行升序排序,descending=False表示升序排序

x = torch.randn(3, 4)
sorted, indices = torch.sort(x,dim=0)
x,sorted,indices

输出结果如下:
(tensor([[ 0.7081,  1.0502,  2.0434, -0.2592],
         [ 1.2052,  0.8809,  0.5771,  1.2978],
         [-1.5873, -0.4808, -2.1774, -0.2503]]),
 tensor([[-1.5873, -0.4808, -2.1774, -0.2592],
         [ 0.7081,  0.8809,  0.5771, -0.2503],
         [ 1.2052,  1.0502,  2.0434,  1.2978]]),
 tensor([[2, 2, 2, 0],
         [0, 1, 1, 2],
         [1, 0, 0, 1]]))

3.3 dim = 0 表示对每列中的元素进行降序排序,descending=True表示降序排序

x = torch.randn(3, 4)
sorted, indices = torch.sort(x,dim=0,descending=True)
x,sorted,indices

输出结果如下:
(tensor([[ 0.9142, -0.2178,  0.5602,  2.3951],
         [-0.6977,  0.4915,  0.3988,  0.6406],
         [ 0.4880,  1.1646, -0.3466,  0.5801]]),
 tensor([[ 0.9142,  1.1646,  0.5602,  2.3951],
         [ 0.4880,  0.4915,  0.3988,  0.6406],
         [-0.6977, -0.2178, -0.3466,  0.5801]]),
 tensor([[0, 2, 0, 0],
         [2, 1, 1, 1],
         [1, 0, 2, 2]]))

3.4 dim = 1 表示对每行中的元素进行降序排序,descending=True表示降序排序

x = torch.randn(3, 4)
sorted, indices = torch.sort(x,dim=1,descending=True)
x,sorted,indices

输出结果如下:
(tensor([[-0.3048, -1.9915, -0.0888,  0.3881],
         [ 1.0677, -1.3520,  0.2944, -0.0772],
         [-0.9409, -0.9630, -0.7946,  1.4400]]),
 tensor([[ 0.3881, -0.0888, -0.3048, -1.9915],
         [ 1.0677,  0.2944, -0.0772, -1.3520],
         [ 1.4400, -0.7946, -0.9409, -0.9630]]),
 tensor([[3, 2, 0, 1],
         [0, 2, 3, 1],
         [3, 2, 0, 1]]))

二.torch.argsort()函数解析

1. 官网链接

torch.argsort(),如下图所示:

image.png

2. torch.argsort()函数解析

用法跟上面torch.sort()函数一样,不同的是torch.argsort()返回只是排序后的值所对应原输入input的下标,即torch.sort()返回的indices

3. 代码举例

dim = 1 表示对每行中的元素进行降序排序,descending=True表示降序排序,输出结果为返回排序后的值所对应原输入input的下标indices

x = torch.randn(3, 4)
indices = torch.argsort(x,dim=1,descending=True)
x,indices

输出结果如下:
(tensor([[-0.6069, -0.9252, -0.9177,  0.6997],
         [ 0.3245, -0.0665,  0.4600,  0.0722],
         [-1.0662,  2.2669, -0.1171, -0.9208]]),
 tensor([[3, 0, 2, 1],
         [2, 0, 3, 1],
         [1, 2, 3, 0]]))

参考知识文章

你可能感兴趣的:(Pytorch中torch.sort()和torch.argsort()函数解析)