关于函数torch.topk用法的思考

文章目录

      • 1. 沿着`dim=0`
      • 2. 沿着`dim=1`
      • 3. 沿着`dim=2`
      • 4. 总结

开始介绍之前先来点哲理性的思考,为什么函数 torch.topk,他的名字会叫 topk呢?

个人认为名称是来源于 “top k”,在这种情况下,它表示 “前 k 个最大值”。

假设我们有一个形状为 ( 2 , 3 , 4 ) (2, 3, 4) (2,3,4) 的三维张量 A A A,如下所示:

A = torch.tensor([[[ 1,  3,  5,  7],
                   [ 2,  4,  6,  8],
                   [ 9, 11, 13, 15]],
                  [[16, 18, 20, 22],
                   [17, 19, 21, 23],
                   [10, 12, 14, 24]]])

1. 沿着dim=0

沿着 dim=0(即在子矩阵之间进行比较):

k = 1
topk_values, topk_indices = torch.topk(A, k=k, dim=0)
A = torch.tensor([[[ 1,  3,  5,  7],
                   [ 2,  4,  6,  8],
                   [ 9, 11, 13, 15]],
                  [[16, 18, 20, 22],
                   [17, 19, 21, 23],
                   [10, 12, 14, 24]]])

沿着 dim=0(即在子矩阵之间进行比较):

topk_values = tensor([[[16, 18, 20, 22],
                        [17, 19, 21, 23],
                        [10, 12, 14, 24]]])

topk_indices = tensor([[[1, 1, 1, 1],
                         [1, 1, 1, 1],
                         [1, 1, 1, 1]]])

那此时我们令k = 3会发生什么?很显然,我们并没有三个子矩阵,所以此时程序会报错。给大家看一下程序的错误:

Traceback (most recent call last):
  File "E:\Learning_Material\Junior_Second_Semester\Adademic_Research\BusterNet_pytorch-master\test_2023_4_7.py", line 14, in <module>
    topk_values, topk_indices = torch.topk(A, k=k, dim=0)
RuntimeError: selected index k out of range

2. 沿着dim=1

沿着 dim=1(即在行之间进行比较):

k = 2
topk_values, topk_indices = torch.topk(A, k=k, dim=1)
topk_values = tensor([[[ 9, 11, 13, 15],
         			   [ 2,  4,  6,  8]],

        			   [[17, 19, 21, 24],
         				[16, 18, 20, 23]]])

topk_indices = tensor([[[2, 2, 2, 2],
         				[1, 1, 1, 1]],

        			   [[1, 1, 1, 2],
         				[0, 0, 0, 1]]])

在沿着行进行比较的情况下,比较是不会在子矩阵,也就是更高维度上发生的。仅仅在子矩阵的维度上比较一个子矩阵中最大的行数。

3. 沿着dim=2

沿着 dim=2(即在列之间进行比较):

k = 2
topk_values, topk_indices = torch.topk(A, k=k, dim=2)
topk_values = tensor([[[ 7,  5],
         			   [ 8,  6],
         			   [15, 13]],

        			  [[22, 20],
         		   	   [23, 21],
         			   [24, 14]]])

topk_indices = tensor([[[3, 2],
         				[3, 2],
         				[3, 2]],

        			   [[3, 2],
         				[3, 2],
         				[3, 2]]])

沿着dim=2比较,此时就不会牵扯到更高的两个维度,只会在最后一个维度之内进行排序比较

4. 总结

没有什么别的经验,希望对大家有用就好!

你可能感兴趣的:(数字图像处理,算法,机器学习,人工智能,线性代数)