首先研究一下dim和k这两个最重要的参数:
import torch
seed = 0
torch.manual_seed(seed)
a = torch.randint(1,10,(3,4,4))
print(a)
tensor([[[9, 1, 3, 7],
[8, 7, 8, 2],
[2, 1, 9, 3],
[7, 4, 2, 3]],
[[1, 1, 6, 4],
[9, 3, 9, 3],
[9, 6, 8, 9],
[7, 1, 2, 1]],
[[9, 7, 2, 7],
[2, 9, 9, 8],
[3, 4, 8, 8],
[2, 6, 5, 8]]])
values , indices = a.topk(2,dim=0)
print(values.shape)
print(values)
print(indices.shape)
print(indices)
torch.Size([2, 4, 4])
tensor([[[9, 7, 6, 7],
[9, 9, 9, 8],
[9, 6, 9, 9],
[7, 6, 5, 8]],
[[9, 1, 3, 7],
[8, 7, 9, 3],
[3, 4, 8, 8],
[7, 4, 2, 3]]])
torch.Size([2, 4, 4])
tensor([[[0, 2, 1, 0],
[1, 2, 1, 2],
[1, 1, 0, 1],
[0, 2, 2, 2]],
[[2, 0, 0, 2],
[0, 0, 2, 1],
[2, 2, 1, 2],
[1, 0, 0, 0]]])
**这里说一下我的理解 **
首先dim=0这个参数表示在某一维取topk,在我的代码中就是取前2个。首先看输出的values和indices的张量形状:(2,4,4)这里可以结合下面两个实验总结出规律,dim取几,输出结果的形状就是:其他维度不变,对应维度变成k。
现在dim=0最后的输出就是要变成(2,4,4)也就是之前第一维中保留两个最大的。
看values的值,讲一下第一行元素[9,7,6,7]是如何得来的:
因为dim=0所以要从第0维来看,将数据分成3份,分别是:
1. [[9, 1, 3, 7],
[8, 7, 8, 2],
[2, 1, 9, 3],
[7, 4, 2, 3]]
2. [[1, 1, 6, 4],
[9, 3, 9, 3],
[9, 6, 8, 9],
[7, 1, 2, 1]]
3. [[9, 7, 2, 7],
[2, 9, 9, 8],
[3, 4, 8, 8],
[2, 6, 5, 8]]
要以这三个tensor为单位进行topk的筛选,首先比较每一个tensor的第一行,因为参数k为2,所以就要找到这3组元素中的最大值和次大值,作为最后的输出。因此最大值就是[9,7,6,7],次大值为:[9,1,3,7]这样就完成了筛选。索引值也就是当前位置处的元素,是来自于这三个元素中的哪一个。我认为把这个看懂后面就可以迎刃而解,大家可以仔细理解一下不太懂的话也没关系,看完后面两个可能这个就懂了。
values1 , indices1 = a.topk(2,dim=1)
print(values1)
print(indices1)
torch.Size([3, 2, 4])
tensor([[[9, 7, 9, 7],
[8, 4, 8, 3]],
[[9, 6, 9, 9],
[9, 3, 8, 4]],
[[9, 9, 9, 8],
[3, 7, 8, 8]]])
torch.Size([3, 2, 4])
tensor([[[0, 1, 2, 0],
[1, 3, 1, 2]],
[[1, 2, 1, 2],
[2, 1, 2, 0]],
[[0, 1, 1, 2],
[2, 0, 2, 3]]])
这个例子是dim=1时,类比于dim=0的情况。这里是对第一维进行筛选操作。需要注意的是这里第0维的三个元素是分开操作的。这里我提供一种我自己的理解思路大家借鉴。首先还是按照第0维将tensor分为3块
1. [[9, 1, 3, 7],
[8, 7, 8, 2],
[2, 1, 9, 3],
[7, 4, 2, 3]]
2. [[1, 1, 6, 4],
[9, 3, 9, 3],
[9, 6, 8, 9],
[7, 1, 2, 1]]
3. [[9, 7, 2, 7],
[2, 9, 9, 8],
[3, 4, 8, 8],
[2, 6, 5, 8]]
这里每一块中的第0维就是总体tensor的第一维,从第0维来看就是4个14的向量,因此就是对这4向量取最大值和次大值。也就是在这个44的张量中选出对应位置的最大值和次大值。例如第一块中筛选出的结果就是[9,7,9,7]和[8,4,8,3]其他同理,索引值表示当前位置处的值是来自哪一个向量。
values2 , indices2 = a.topk(2,dim=2)
print(values2.shape)
print(values2)
print(indices2.shape)
print(indices2)
torch.Size([3, 4, 2])
tensor([[[9, 7],
[8, 8],
[9, 3],
[7, 4]],
[[6, 4],
[9, 9],
[9, 9],
[7, 2]],
[[9, 7],
[9, 9],
[8, 8],
[8, 6]]])
torch.Size([3, 4, 2])
tensor([[[0, 3],
[0, 2],
[2, 3],
[0, 1]],
[[2, 3],
[0, 2],
[3, 0],
[0, 2]],
[[0, 1],
[1, 2],
[2, 3],
[3, 1]]])
类比前两种情况的思考方式,这里的操作就是对整个张量最内层做的操作,也就是整体张量形状(3,4,4)中的4这个4就是最内层每一个一维向量中的4个元素,取对应的最大值和次大值,应该也容易理解。大家可以对比着三种情况的输入输出加以理解。
另外,k参数默认是最后一维
然后研究一下lagest参数:
直接用最后一维
values2 , indices2 = a.topk(2,dim=2,largest=False)
print(values2.shape)
print(values2)
print(indices2.shape)
print(indices2)
torch.Size([3, 4, 2])
tensor([[[1, 3],
[2, 7],
[1, 2],
[2, 3]],
[[1, 1],
[3, 3],
[6, 8],
[1, 1]],
[[2, 7],
[2, 8],
[3, 4],
[2, 5]]])
torch.Size([3, 4, 2])
tensor([[[1, 2],
[3, 1],
[1, 0],
[2, 3]],
[[1, 0],
[1, 3],
[1, 2],
[1, 3]],
[[2, 3],
[0, 3],
[0, 1],
[0, 2]]])
很明显,只是取最小和次小
下面是sorted参数
依然用dim=2进行测试
values2 , indices2 = a.topk(2,dim=2,sorted=False)
print(values2.shape)
print(values2)
print(indices2.shape)
print(indices2)
torch.Size([3, 4, 2])
tensor([[[9, 7],
[8, 8],
[9, 3],
[7, 4]],
[[6, 4],
[9, 9],
[9, 9],
[7, 2]],
[[9, 7],
[9, 9],
[8, 8],
[8, 6]]])
torch.Size([3, 4, 2])
tensor([[[0, 3],
[0, 2],
[2, 3],
[0, 1]],
[[2, 3],
[0, 2],
[3, 0],
[0, 2]],
[[0, 1],
[1, 2],
[2, 3],
[3, 1]]])
这里和不sorted=True好像并没有区别,不知道要怎么理解,网上也没找到类似的解释,希望有知道的大佬可以多多指教!!