Reference: torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
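The sessions below never call torch.topk itself, so here is a minimal sketch of calling it directly. It assumes the same 15 integers that the seeded torch.randint call produces below, hard-coded so the result does not depend on the random generator. topk returns the selected values sorted by magnitude, not in their original positions. The double torch.sort approach used in the sessions keeps the original positions.

```python
import torch

# The 15 values torch.randint(100, (15,)) produced in the sessions below
data = torch.tensor([63, 48, 14, 47, 28, 5, 80, 68, 88, 61, 6, 84, 82, 87, 59])
k = 7

# largest=False selects the k smallest entries; sorted=True returns them ascending
small_vals, small_idx = torch.topk(data, k, largest=False)
# the default largest=True selects the k largest entries, returned descending
large_vals, large_idx = torch.topk(data, k)

print(small_vals.tolist())  # [5, 6, 14, 28, 47, 48, 59]
print(large_vals.tolist())  # [88, 87, 84, 82, 80, 68, 63]
```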
Selecting the k smallest numbers (kept in their original order):
Microsoft Windows [Version 10.0.18363.1256]
(c) 2019 Microsoft Corporation. All rights reserved.
C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May 6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x0000026DB7CFD330>
>>>
>>> data = torch.randint(100,(15,))
>>> data
tensor([63, 48, 14, 47, 28, 5, 80, 68, 88, 61, 6, 84, 82, 87, 59])
>>> # select the k smallest numbers
>>> k = 7
>>> a, idx1 = torch.sort(data)
>>> b, idx2 = torch.sort(idx1)
>>> a
tensor([ 5, 6, 14, 28, 47, 48, 59, 61, 63, 68, 80, 82, 84, 87, 88])
>>> data
tensor([63, 48, 14, 47, 28, 5, 80, 68, 88, 61, 6, 84, 82, 87, 59])
>>> data[idx2<k]
tensor([48, 14, 47, 28, 5, 6, 59])
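Why data[idx2<k] works: idx1 lists the positions of data in ascending order of value. Sorting idx1 a second time inverts that permutation, so idx2[i] is the rank of data[i], and idx2<k marks exactly the k smallest elements. The plain-Python sketch below mirrors the two torch.sort calls so the ranks can be inspected directly. The argsort helper is hypothetical and stands in for the index output of torch.sort.

```python
def argsort(values):
    """Hypothetical helper: positions that would sort `values` ascending,
    i.e. the index half of torch.sort(values)."""
    return sorted(range(len(values)), key=lambda i: values[i])

data = [63, 48, 14, 47, 28, 5, 80, 68, 88, 61, 6, 84, 82, 87, 59]
k = 7

idx1 = argsort(data)  # idx1[r] is the position of the element with rank r
idx2 = argsort(idx1)  # inverse permutation: idx2[i] is the rank of data[i]

# keep every element whose rank is below k; the original order is untouched
smallest_k = [x for x, rank in zip(data, idx2) if rank < k]

print(idx2)        # [8, 5, 2, 4, 3, 0, 10, 9, 14, 7, 1, 12, 11, 13, 6]
print(smallest_k)  # [48, 14, 47, 28, 5, 6, 59]
```

Duplicate values also work: the ranks still form a permutation of 0..n-1, so exactly k elements are kept. When several tied elements compete for the last slot, the one kept depends on how the sort breaks ties.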
Selecting the k largest numbers (kept in their original order):
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001846D60D330>
>>>
>>>
>>> data = torch.randint(100,(15,))
>>> data
tensor([63, 48, 14, 47, 28, 5, 80, 68, 88, 61, 6, 84, 82, 87, 59])
>>> # select the k largest numbers
>>>
>>> k = 7
>>> a, idx1 = torch.sort(data, descending=True)
>>> b, idx2 = torch.sort(idx1)
>>>
>>> a
tensor([88, 87, 84, 82, 80, 68, 63, 61, 59, 48, 47, 28, 14, 6, 5])
>>> data
tensor([63, 48, 14, 47, 28, 5, 80, 68, 88, 61, 6, 84, 82, 87, 59])
>>> data[idx2<k]
tensor([63, 80, 68, 88, 84, 82, 87])
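The same order-preserving selection can also be built on torch.topk: take its indices, sort them to restore positional order, then index data. The sketch below compares that route with the double-sort rank mask from the session. Both helper names are hypothetical, and the data is the same seeded tensor, hard-coded.

```python
import torch

def topk_keep_order(data, k, largest=True):
    # hypothetical helper: topk indices, re-sorted so values come back in original order
    _, idx = torch.topk(data, k, largest=largest)
    sorted_idx, _ = torch.sort(idx)
    return data[sorted_idx]

def topk_by_rank(data, k, largest=True):
    # hypothetical helper: the double-sort rank mask used in the sessions
    _, idx1 = torch.sort(data, descending=largest)
    _, idx2 = torch.sort(idx1)
    return data[idx2 < k]

data = torch.tensor([63, 48, 14, 47, 28, 5, 80, 68, 88, 61, 6, 84, 82, 87, 59])
k = 7

print(topk_keep_order(data, k).tolist())                 # [63, 80, 68, 88, 84, 82, 87]
print(topk_by_rank(data, k).tolist())                    # [63, 80, 68, 88, 84, 82, 87]
print(topk_keep_order(data, k, largest=False).tolist())  # [48, 14, 47, 28, 5, 6, 59]
```

The rank mask sorts all n elements twice. The topk route only sorts the k selected indices after the selection, so it will likely do less work for large tensors with a small k.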
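Selecting the k largest numbers from a tensor that requires grad (gradients flow back only to the selected elements):
Microsoft Windows [Version 10.0.18363.1316]
(c) 2019 Microsoft Corporation. All rights reserved.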
C:\Users\chenxuqi>conda activate ssd4pytorch1_2_0
(ssd4pytorch1_2_0) C:\Users\chenxuqi>python
Python 3.7.7 (default, May 6 2020, 11:45:54) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001F7F101D330>
>>>
>>> data = torch.randn(15,requires_grad=True)
>>> data
tensor([ 0.2824, -0.3715, 0.9088, -1.7601, -0.1806, 2.0937, 1.0406, -1.7651,
1.1216, 0.8440, 0.1783, 0.6859, -1.5942, -0.2006, -0.4050],
requires_grad=True)
>>>
>>> # select the k largest numbers
>>> k = 7
>>> a, idx1 = torch.sort(data, descending=True)
>>> b, idx2 = torch.sort(idx1)
>>> a
tensor([ 2.0937, 1.1216, 1.0406, 0.9088, 0.8440, 0.6859, 0.2824, 0.1783,
-0.1806, -0.2006, -0.3715, -0.4050, -1.5942, -1.7601, -1.7651],
grad_fn=<SortBackward>)
>>> b
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])
>>> idx1
tensor([ 5, 8, 6, 2, 9, 11, 0, 10, 4, 13, 1, 14, 12, 3, 7])
>>> idx2
tensor([ 6, 10, 3, 13, 8, 0, 2, 14, 1, 4, 7, 5, 12, 9, 11])
>>>
>>> a
tensor([ 2.0937, 1.1216, 1.0406, 0.9088, 0.8440, 0.6859, 0.2824, 0.1783,
-0.1806, -0.2006, -0.3715, -0.4050, -1.5942, -1.7601, -1.7651],
grad_fn=<SortBackward>)
>>> data
tensor([ 0.2824, -0.3715, 0.9088, -1.7601, -0.1806, 2.0937, 1.0406, -1.7651,
1.1216, 0.8440, 0.1783, 0.6859, -1.5942, -0.2006, -0.4050],
requires_grad=True)
>>> data[idx2<k]
tensor([0.2824, 0.9088, 2.0937, 1.0406, 1.1216, 0.8440, 0.6859],
grad_fn=<IndexBackward>)
>>> sum_topK = data[idx2<k].sum()
>>> sum_topK
tensor(6.9770, grad_fn=<SumBackward0>)
>>>
>>> 2.0937+1.1216+1.0406+0.9088+0.8440+0.6859+0.2824
6.977000000000001
>>> data.grad
>>> print(data.grad)
None
>>> sum_topK.backward()
>>> print(data.grad)
tensor([1., 0., 1., 0., 0., 1., 1., 0., 1., 1., 0., 1., 0., 0., 0.])
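The gradient printed above is 1 exactly at the seven selected positions and 0 elsewhere, because the sum has derivative 1 with respect to each element it includes. The sketch below checks that backpropagating through the values returned by torch.topk gives the same gradient as the rank-mask route. It assumes data.grad may be reset to None between the two backward passes.

```python
import torch

torch.manual_seed(20200910)
data = torch.randn(15, requires_grad=True)
k = 7

# route 1: rank mask built from two sorts, as in the session
_, idx1 = torch.sort(data, descending=True)
_, idx2 = torch.sort(idx1)
data[idx2 < k].sum().backward()
grad_mask = data.grad.clone()

# route 2: backpropagate through the values returned by torch.topk
data.grad = None  # clear the accumulated gradient from route 1
values, _ = torch.topk(data, k)
values.sum().backward()
grad_topk = data.grad.clone()

# both routes put 1 on the k selected entries and 0 everywhere else
print(torch.equal(grad_mask, grad_topk))  # True
print(int(grad_mask.sum()))               # 7
```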