代码:
import torch
input = torch.randn(2,2,3,3) # 生成随机张量
print("input",input)
print("input.size",input.size())
output = input.argmax(dim=1) # 沿着维度1,将最大值所在的维度数(是维度1中的0或1)选择出来
print("output", output)
print("output.size", output.size())
结果:
input tensor([[[[ 0.9604, -0.1921, -0.3469],
[-1.2351, 0.5378, -1.4252],
[ 0.6342, 0.1409, -0.0997]],
[[-0.6485, -1.7757, -0.7482],
[ 1.0768, -1.4759, 0.8779],
[ 1.2902, -0.1341, -1.1096]]],
[[[ 0.6024, 0.2546, 1.1767],
[-1.1221, 1.1944, 0.2713],
[ 0.0342, -2.0553, 0.4262]],
[[-0.0980, -0.6603, -0.7870],
[-1.0530, 0.3212, -0.4535],
[-1.4209, 0.2561, 1.0194]]]])
input.size torch.Size([2, 2, 3, 3])
output tensor([[[0, 0, 0],
[1, 0, 1],
[1, 0, 0]],
[[0, 0, 0],
[1, 0, 0],
[0, 1, 1]]])
output.size torch.Size([2, 3, 3]) # 沿着维度1计算,最后维度1被合并