torch.topk
,他的名字会叫 topk
呢?
个人认为名称是来源于 “top k”,在这种情况下,它表示 “前 k 个最大值”。
假设我们有一个形状为 ( 2 , 3 , 4 ) (2, 3, 4) (2,3,4) 的三维张量 A A A,如下所示:
A = torch.tensor([[[ 1, 3, 5, 7],
[ 2, 4, 6, 8],
[ 9, 11, 13, 15]],
[[16, 18, 20, 22],
[17, 19, 21, 23],
[10, 12, 14, 24]]])
dim=0
沿着 dim=0(即在子矩阵之间进行比较):
k = 1
topk_values, topk_indices = torch.topk(A, k=k, dim=0)
A = torch.tensor([[[ 1, 3, 5, 7],
[ 2, 4, 6, 8],
[ 9, 11, 13, 15]],
[[16, 18, 20, 22],
[17, 19, 21, 23],
[10, 12, 14, 24]]])
沿着 dim=0
(即在子矩阵之间进行比较):
topk_values = tensor([[[16, 18, 20, 22],
[17, 19, 21, 23],
[10, 12, 14, 24]]])
topk_indices = tensor([[[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]]])
那此时我们令k = 3
会发生什么?很显然,我们并没有三个子矩阵,所以此时程序会报错。给大家看一下程序的错误:
Traceback (most recent call last):
File "E:\Learning_Material\Junior_Second_Semester\Adademic_Research\BusterNet_pytorch-master\test_2023_4_7.py", line 14, in <module>
topk_values, topk_indices = torch.topk(A, k=k, dim=0)
RuntimeError: selected index k out of range
dim=1
沿着 dim=1
(即在行之间进行比较):
k = 2
topk_values, topk_indices = torch.topk(A, k=k, dim=1)
topk_values = tensor([[[ 9, 11, 13, 15],
[ 2, 4, 6, 8]],
[[17, 19, 21, 24],
[16, 18, 20, 23]]])
topk_indices = tensor([[[2, 2, 2, 2],
[1, 1, 1, 1]],
[[1, 1, 1, 2],
[0, 0, 0, 1]]])
在沿着行进行比较的情况下,比较是不会在子矩阵,也就是更高维度上发生的。仅仅在子矩阵的维度上比较一个子矩阵中最大的行数。
dim=2
沿着 dim=2
(即在列之间进行比较):
k = 2
topk_values, topk_indices = torch.topk(A, k=k, dim=2)
topk_values = tensor([[[ 7, 5],
[ 8, 6],
[15, 13]],
[[22, 20],
[23, 21],
[24, 14]]])
topk_indices = tensor([[[3, 2],
[3, 2],
[3, 2]],
[[3, 2],
[3, 2],
[3, 2]]])
沿着dim=2
比较,此时就不会牵扯到更高的两个维度,只会在最后一个维度之内进行排序比较
没有什么别的经验,希望对大家有用就好!