一、Argmax函数

Argmax函数

代码:

import torch
input = torch.randn(2,2,3,3) # 生成随机张量
print("input",input)
print("input.size",input.size())
output = input.argmax(dim=1) # 沿着维度1,将最大值所在的维度数(是维度1中的0或1)选择出来
print("output", output)
print("output.size", output.size())

结果:

input tensor([[[[ 0.9604, -0.1921, -0.3469],
          [-1.2351,  0.5378, -1.4252],
          [ 0.6342,  0.1409, -0.0997]],

         [[-0.6485, -1.7757, -0.7482],
          [ 1.0768, -1.4759,  0.8779],
          [ 1.2902, -0.1341, -1.1096]]],


        [[[ 0.6024,  0.2546,  1.1767],
          [-1.1221,  1.1944,  0.2713],
          [ 0.0342, -2.0553,  0.4262]],

         [[-0.0980, -0.6603, -0.7870],
          [-1.0530,  0.3212, -0.4535],
          [-1.4209,  0.2561,  1.0194]]]])
input.size torch.Size([2, 2, 3, 3])
output tensor([[[0, 0, 0],
         [1, 0, 1],
         [1, 0, 0]],

        [[0, 0, 0],
         [1, 0, 0],
         [0, 1, 1]]])
output.size torch.Size([2, 3, 3]) # 沿着维度1计算,最后维度1被合并

你可能感兴趣的:(#,常用函数,深度学习,python,人工智能)