pytorch中topk()用法的测试与个人理解

参数介绍:

直接官网的介绍topk()
pytorch中topk()用法的测试与个人理解_第1张图片

  • input:就是输入的tensor,也就是要取topk的张量
  • k:就是取前k个最大的值。
  • dim:就是在哪一维来取这k个值。
  • lagest:默认是true表示取前k大的值,false则表示取前k小的值
  • sorted:是否按照顺序输出,默认是true。
  • out : 可选输出张量 (Tensor, LongTensor)

直接上代码:


首先研究一下dim和k这两个最重要的参数:

import torch
seed = 0
torch.manual_seed(seed)
a = torch.randint(1,10,(3,4,4))

print(a)
tensor([[[9, 1, 3, 7],
         [8, 7, 8, 2],
         [2, 1, 9, 3],
         [7, 4, 2, 3]],

        [[1, 1, 6, 4],
         [9, 3, 9, 3],
         [9, 6, 8, 9],
         [7, 1, 2, 1]],

        [[9, 7, 2, 7],
         [2, 9, 9, 8],
         [3, 4, 8, 8],
         [2, 6, 5, 8]]])
values , indices = a.topk(2,dim=0)
print(values.shape)
print(values)
print(indices.shape)
print(indices)
torch.Size([2, 4, 4])
tensor([[[9, 7, 6, 7],
         [9, 9, 9, 8],
         [9, 6, 9, 9],
         [7, 6, 5, 8]],

        [[9, 1, 3, 7],
         [8, 7, 9, 3],
         [3, 4, 8, 8],
         [7, 4, 2, 3]]])
torch.Size([2, 4, 4])
tensor([[[0, 2, 1, 0],
         [1, 2, 1, 2],
         [1, 1, 0, 1],
         [0, 2, 2, 2]],

        [[2, 0, 0, 2],
         [0, 0, 2, 1],
         [2, 2, 1, 2],
         [1, 0, 0, 0]]])

**这里说一下我的理解 **

首先dim=0这个参数表示在某一维取topk,在我的代码中就是取前2个。首先看输出的values和indices的张量形状:(2,4,4)这里可以结合下面两个实验总结出规律,dim取几,输出结果的形状就是:其他维度不变,对应维度变成k。

现在dim=0最后的输出就是要变成(2,4,4)也就是之前第一维中保留两个最大的。

看values的值,讲一下第一行元素[9,7,6,7]是如何得来的:
因为dim=0所以要从第0维来看,将数据分成3份,分别是:

1. [[9, 1, 3, 7],
    [8, 7, 8, 2],
    [2, 1, 9, 3],
    [7, 4, 2, 3]]
2. [[1, 1, 6, 4],
    [9, 3, 9, 3],
    [9, 6, 8, 9],
    [7, 1, 2, 1]]
3. [[9, 7, 2, 7],
    [2, 9, 9, 8],
    [3, 4, 8, 8],
    [2, 6, 5, 8]]

要以这三个tensor为单位进行topk的筛选,首先比较每一个tensor的第一行,因为参数k为2,所以就要找到这3组元素中的最大值和次大值,作为最后的输出。因此最大值就是[9,7,6,7],次大值为:[9,1,3,7]这样就完成了筛选。索引值也就是当前位置处的元素,是来自于这三个元素中的哪一个。我认为把这个看懂后面就可以迎刃而解,大家可以仔细理解一下不太懂的话也没关系,看完后面两个可能这个就懂了。

values1 , indices1 = a.topk(2,dim=1)
print(values1)
print(indices1)
torch.Size([3, 2, 4])
tensor([[[9, 7, 9, 7],
         [8, 4, 8, 3]],

        [[9, 6, 9, 9],
         [9, 3, 8, 4]],

        [[9, 9, 9, 8],
         [3, 7, 8, 8]]])
torch.Size([3, 2, 4])
tensor([[[0, 1, 2, 0],
         [1, 3, 1, 2]],

        [[1, 2, 1, 2],
         [2, 1, 2, 0]],

        [[0, 1, 1, 2],
         [2, 0, 2, 3]]])

这个例子是dim=1时,类比于dim=0的情况。这里是对第一维进行筛选操作。需要注意的是这里第0维的三个元素是分开操作的。这里我提供一种我自己的理解思路大家借鉴。首先还是按照第0维将tensor分为3块

1. [[9, 1, 3, 7],
    [8, 7, 8, 2],
    [2, 1, 9, 3],
    [7, 4, 2, 3]]
2. [[1, 1, 6, 4],
    [9, 3, 9, 3],
    [9, 6, 8, 9],
    [7, 1, 2, 1]]
3. [[9, 7, 2, 7],
    [2, 9, 9, 8],
    [3, 4, 8, 8],
    [2, 6, 5, 8]]

这里每一块中的第0维就是总体tensor的第一维,从第0维来看就是4个14的向量,因此就是对这4向量取最大值和次大值。也就是在这个44的张量中选出对应位置的最大值和次大值。例如第一块中筛选出的结果就是[9,7,9,7]和[8,4,8,3]其他同理,索引值表示当前位置处的值是来自哪一个向量。

values2 , indices2 = a.topk(2,dim=2)
print(values2.shape)
print(values2)
print(indices2.shape)
print(indices2)
torch.Size([3, 4, 2])
tensor([[[9, 7],
         [8, 8],
         [9, 3],
         [7, 4]],

        [[6, 4],
         [9, 9],
         [9, 9],
         [7, 2]],

        [[9, 7],
         [9, 9],
         [8, 8],
         [8, 6]]])
torch.Size([3, 4, 2])
tensor([[[0, 3],
         [0, 2],
         [2, 3],
         [0, 1]],

        [[2, 3],
         [0, 2],
         [3, 0],
         [0, 2]],

        [[0, 1],
         [1, 2],
         [2, 3],
         [3, 1]]])

类比前两种情况的思考方式,这里的操作就是对整个张量最内层做的操作,也就是整体张量形状(3,4,4)中的4这个4就是最内层每一个一维向量中的4个元素,取对应的最大值和次大值,应该也容易理解。大家可以对比着三种情况的输入输出加以理解。
另外,k参数默认是最后一维

然后研究一下lagest参数:
直接用最后一维

values2 , indices2 = a.topk(2,dim=2,largest=False)
print(values2.shape)
print(values2)
print(indices2.shape)
print(indices2)
torch.Size([3, 4, 2])
tensor([[[1, 3],
         [2, 7],
         [1, 2],
         [2, 3]],

        [[1, 1],
         [3, 3],
         [6, 8],
         [1, 1]],

        [[2, 7],
         [2, 8],
         [3, 4],
         [2, 5]]])
torch.Size([3, 4, 2])
tensor([[[1, 2],
         [3, 1],
         [1, 0],
         [2, 3]],

        [[1, 0],
         [1, 3],
         [1, 2],
         [1, 3]],

        [[2, 3],
         [0, 3],
         [0, 1],
         [0, 2]]])

很明显,只是取最小和次小


下面是sorted参数
依然用dim=2进行测试

values2 , indices2 = a.topk(2,dim=2,sorted=False)
print(values2.shape)
print(values2)
print(indices2.shape)
print(indices2)
torch.Size([3, 4, 2])
tensor([[[9, 7],
         [8, 8],
         [9, 3],
         [7, 4]],

        [[6, 4],
         [9, 9],
         [9, 9],
         [7, 2]],

        [[9, 7],
         [9, 9],
         [8, 8],
         [8, 6]]])
torch.Size([3, 4, 2])
tensor([[[0, 3],
         [0, 2],
         [2, 3],
         [0, 1]],

        [[2, 3],
         [0, 2],
         [3, 0],
         [0, 2]],

        [[0, 1],
         [1, 2],
         [2, 3],
         [3, 1]]])

这里和不sorted=True好像并没有区别,不知道要怎么理解,网上也没找到类似的解释,希望有知道的大佬可以多多指教!!

如有错误请多指正

你可能感兴趣的:(python,pytorch,pytorch,深度学习,python)