pytorch的topk能够返回最大的k个值,现在假设有一个[2,3,4]的权重矩阵,如果我们需要在第三个维度找出最大的两个值,(并保持权重矩阵的维度不变,且最大值的位置也不变),topk就不是很好用了,以下代码能解决这个问题:
import torch
import numpy as np
if __name__ == "__main__":
x=torch.tensor(np.arange(1,25)).reshape(2,3,4)
print(x)
# k=2表示选择两个最大值
a,_=x.topk(k=2,dim=2)
# 要加上values,否则会得到一个包含values和indexs的对象
a_min=torch.min(a,dim=-1).values
# repeat里的4和x的最后一维相同
a_min=a_min.unsqueeze(-1).repeat(1,1,4)
ge=torch.ge(x,a_min)
# 设置zero变量,方便后面的where操作
zero=torch.zeros_like(x)
result=torch.where(ge,x,zero)
print(result)
输出是:
# 原矩阵
tensor([[[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]],
[[13, 14, 15, 16],
[17, 18, 19, 20],
[21, 22, 23, 24]]], dtype=torch.int32)
# 每个维度只保留两个最大值
tensor([[[ 0, 0, 3, 4],
[ 0, 0, 7, 8],
[ 0, 0, 11, 12]],
[[ 0, 0, 15, 16],
[ 0, 0, 19, 20],
[ 0, 0, 23, 24]]], dtype=torch.int32)
topk的输出有两个,其他地方可能会派上用场:
a,b=x.topk(k=2,dim=2)
print(a)
print(b)
# 输出
tensor([[[ 4, 3],
[ 8, 7],
[12, 11]],
[[16, 15],
[20, 19],
[24, 23]]], dtype=torch.int32)
tensor([[[3, 2],
[3, 2],
[3, 2]],
[[3, 2],
[3, 2],
[3, 2]]])
Process finished with exit code 0