torch.argmax方法详解

torch.argmax方法详解

torch.argmax(x, dim),其中x为张量,dim控制比较的维度,返回最大值的索引。
1.当dim=0时

import torch
x = torch.rand(2, 3,2)
print(x)
torch.argmax(x, dim=0)

当dim=0时,表示后两个维度进行比较,得到结果如下图:
torch.argmax方法详解_第1张图片
比较过程为:输出结果的张量y的大小为去掉需比较维度dim后的大小,即3x2。然后依次确定这6个值,首先,对x[:,0,0]中的值进行比较,

torch.argmax方法详解_第2张图片

取较大值的索引值输出结果的值,0.6718>0.6402,即y[0,0]=1;接着,对x[:,0,1]进行比较,
在这里插入图片描述
取较大值的索引值输出结果的值,即y[0,1]=0;以此类推,直到将所有的比较完成。

当dim为1,或2的情况也类似,结果如下:
2.当dim=1
torch.argmax方法详解_第3张图片
3.当dim=2
torch.argmax方法详解_第4张图片

你可能感兴趣的:(pytorch,pytorch)