PyTorch使用torch.sort()函数来筛选出前k个最大的项或者筛选出前k个最小的项

参考链接: torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

筛选出前k个最小的数:

Microsoft Windows [版本 10.0.18363.1256]
(c) 2019 Microsoft Corporation。保留所有权利。

C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0

(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May  6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x0000026DB7CFD330>
>>>
>>> data = torch.randint(100,(15,))
>>> data
tensor([63, 48, 14, 47, 28,  5, 80, 68, 88, 61,  6, 84, 82, 87, 59])
>>> # 筛选出前k个最小的数
>>> k = 7
>>> a, idx1 = torch.sort(data)
>>> b, idx2 = torch.sort(idx1)
>>> a
tensor([ 5,  6, 14, 28, 47, 48, 59, 61, 63, 68, 80, 82, 84, 87, 88])
>>> data
tensor([63, 48, 14, 47, 28,  5, 80, 68, 88, 61,  6, 84, 82, 87, 59])
>>> data[idx2<k]
tensor([48, 14, 47, 28,  5,  6, 59])
>>>
>>>
>>>

筛选出前k个最大的数:

Microsoft Windows [版本 10.0.18363.1256]
(c) 2019 Microsoft Corporation。保留所有权利。

C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0

(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May  6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001846D60D330>
>>>
>>>
>>> data = torch.randint(100,(15,))
>>> data
tensor([63, 48, 14, 47, 28,  5, 80, 68, 88, 61,  6, 84, 82, 87, 59])
>>> # 筛选出前k个最大的数
>>>
>>> k = 7
>>> a, idx1 = torch.sort(data, descending=True)
>>> b, idx2 = torch.sort(idx1)
>>>
>>> a
tensor([88, 87, 84, 82, 80, 68, 63, 61, 59, 48, 47, 28, 14,  6,  5])
>>> data
tensor([63, 48, 14, 47, 28,  5, 80, 68, 88, 61,  6, 84, 82, 87, 59])
>>> data[idx2<k]
tensor([63, 80, 68, 88, 84, 82, 87])
>>>
>>>
>>>

简单证明:
PyTorch使用torch.sort()函数来筛选出前k个最大的项或者筛选出前k个最小的项_第1张图片
反向传播,只传播到前K个最大的项:

Microsoft Windows [版本 10.0.18363.1316]
(c) 2019 Microsoft Corporation。保留所有权利。

C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0

(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May  6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001F7F101D330>
>>>
>>> data = torch.randn(15,requires_grad=True)
>>> data
tensor([ 0.2824, -0.3715,  0.9088, -1.7601, -0.1806,  2.0937,  1.0406, -1.7651,
         1.1216,  0.8440,  0.1783,  0.6859, -1.5942, -0.2006, -0.4050],
       requires_grad=True)
>>>
>>> # 筛选出前k个最大的数
>>> k = 7
>>> a, idx1 = torch.sort(data, descending=True)
>>> b, idx2 = torch.sort(idx1)
>>> a
tensor([ 2.0937,  1.1216,  1.0406,  0.9088,  0.8440,  0.6859,  0.2824,  0.1783,
        -0.1806, -0.2006, -0.3715, -0.4050, -1.5942, -1.7601, -1.7651],
       grad_fn=<SortBackward>)
>>> b
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])
>>> idx1
tensor([ 5,  8,  6,  2,  9, 11,  0, 10,  4, 13,  1, 14, 12,  3,  7])
>>> idx2
tensor([ 6, 10,  3, 13,  8,  0,  2, 14,  1,  4,  7,  5, 12,  9, 11])
>>>
>>> a
tensor([ 2.0937,  1.1216,  1.0406,  0.9088,  0.8440,  0.6859,  0.2824,  0.1783,
        -0.1806, -0.2006, -0.3715, -0.4050, -1.5942, -1.7601, -1.7651],
       grad_fn=<SortBackward>)
>>> data
tensor([ 0.2824, -0.3715,  0.9088, -1.7601, -0.1806,  2.0937,  1.0406, -1.7651,
         1.1216,  0.8440,  0.1783,  0.6859, -1.5942, -0.2006, -0.4050],
       requires_grad=True)
>>> data[idx2<k]
tensor([0.2824, 0.9088, 2.0937, 1.0406, 1.1216, 0.8440, 0.6859],
       grad_fn=<IndexBackward>)
>>> sum_topK = data[idx2<k].sum()
>>> sum_topK
tensor(6.9770, grad_fn=<SumBackward0>)
>>>
>>> 2.0937+1.1216+1.0406+0.9088+0.8440+0.6859+0.2824
6.977000000000001
>>> data.grad
>>> print(data.grad)
None
>>> sum_topK.backward()
>>> print(data.grad)
tensor([1., 0., 1., 0., 0., 1., 1., 0., 1., 1., 0., 1., 0., 0., 0.])
>>>
>>>
>>>

你可能感兴趣的:(PyTorch使用torch.sort()函数来筛选出前k个最大的项或者筛选出前k个最小的项)