pytorch topk 保持维度和位置 置零

pytorch的topk能够返回最大的k个值,现在假设有一个[2,3,4]的权重矩阵,如果我们需要在第三个维度找出最大的两个值,(并保持权重矩阵的维度不变,且最大值的位置也不变),topk就不是很好用了,以下代码能解决这个问题:

import torch
import numpy as np
if __name__ == "__main__":
    x=torch.tensor(np.arange(1,25)).reshape(2,3,4)
    print(x)
    # k=2表示选择两个最大值
    a,_=x.topk(k=2,dim=2)
    # 要加上values,否则会得到一个包含values和indexs的对象
    a_min=torch.min(a,dim=-1).values
    # repeat里的4和x的最后一维相同
    a_min=a_min.unsqueeze(-1).repeat(1,1,4)
    ge=torch.ge(x,a_min)
    # 设置zero变量,方便后面的where操作
    zero=torch.zeros_like(x)
    result=torch.where(ge,x,zero)
    print(result)

输出是:

# 原矩阵
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]],

        [[13, 14, 15, 16],
         [17, 18, 19, 20],
         [21, 22, 23, 24]]], dtype=torch.int32)
# 每个维度只保留两个最大值
tensor([[[ 0,  0,  3,  4],
         [ 0,  0,  7,  8],
         [ 0,  0, 11, 12]],

        [[ 0,  0, 15, 16],
         [ 0,  0, 19, 20],
         [ 0,  0, 23, 24]]], dtype=torch.int32)

topk的输出有两个,其他地方可能会派上用场:

    a,b=x.topk(k=2,dim=2)
    print(a)
    print(b)
    
# 输出    
tensor([[[ 4,  3],
         [ 8,  7],
         [12, 11]],

        [[16, 15],
         [20, 19],
         [24, 23]]], dtype=torch.int32)
tensor([[[3, 2],
         [3, 2],
         [3, 2]],

        [[3, 2],
         [3, 2],
         [3, 2]]])

Process finished with exit code 0


你可能感兴趣的:(python,深度学习,深度学习)