pytorch 只保留tensor的最大值或最小值,其他位置置零

如下,x是输入张量,dim指定维度,max可以替换成min


import torch

if __name__ == '__main__':
    
    x = torch.randn([1, 3, 4, 4]).cuda()

    mask = (x == x.max(dim=1, keepdim=True)[0]).to(dtype=torch.int32)
    result = torch.mul(mask, x)

    print(x)
    print(mask)
    print(result)

输出效果:

tensor([[[[-0.8807,  0.1029,  0.0184,  1.2695],
          [-0.0934,  1.0650, -0.2927,  0.0049],
          [ 0.2338, -1.8663,  1.2763,  0.7248],
          [-1.5138,  0.6834,  0.1463,  0.0650]],

         [[ 0.5020,  1.6078, -0.0104,  1.2042],
          [ 1.8859, -0.4682, -0.1177,  0.5197],
          [ 1.7649,  0.4585,  0.6002,  0.3350],
          [-1.1384, -0.0325,  0.8490,  0.6080]],

         [[-0.5618,  0.5388, -0.0572, -0.7240],
          [-0.3458,  1.3494, -0.0603, -1.1562],
          [-0.3652,  1.1885,  1.6293,  0.4134],
          [ 1.3009,  1.2027, -0.8711,  1.3321]]]], device='cuda:0')
tensor([[[[0, 0, 1, 1],
          [0, 0, 0, 0],
          [0, 0, 0, 1],
          [0, 0, 0, 0]],

         [[1, 1, 0, 0],
          [1, 0, 0, 1],
          [1, 0, 0, 0],
          [0, 0, 1, 0]],

         [[0, 0, 0, 0],
          [0, 1, 1, 0],
          [0, 1, 1, 0],
          [1, 1, 0, 1]]]], device='cuda:0', dtype=torch.int32)
tensor([[[[-0.0000,  0.0000,  0.0184,  1.2695],
          [-0.0000,  0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.7248],
          [-0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.5020,  1.6078, -0.0000,  0.0000],
          [ 1.8859, -0.0000, -0.0000,  0.5197],
          [ 1.7649,  0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.8490,  0.0000]],

         [[-0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  1.3494, -0.0603, -0.0000],
          [-0.0000,  1.1885,  1.6293,  0.0000],
          [ 1.3009,  1.2027, -0.0000,  1.3321]]]], device='cuda:0')

Process finished with exit code 0

参考链接:
https://discuss.pytorch.org/t/keep-the-max-value-of-the-array-and-0-the-others/14480/8

你可能感兴趣的:(Python,pytorch,其他,深度学习)