pytorch中torch.max()的用法

用法:
(max, max_indices) = torch.max(input, dim, keepdim=False)

输入: input是输入的tensor,dim指定在哪一维度求最大值,keepdim表示是否需要保持输出的维度与输入一样,keepdim=True表示输出和输入的维度一样,keepdim=False表示输出的维度被压缩了,也就是输出会比输入低一个维度。
输出: max为最大值的结果,max_indices为相应最大值的索引。

对dim的详细讨论:

  • 没有传递dim值:在输入的整个tensor中求最大值
  • 如果 0 ≤ d i m ≤ x . d i m ( ) − 1 0≤dim≤x.dim()-1 0dimx.dim()1 ,x.dim()代表输入数据x的维度, 表示在第几维度上求最大值
  • 如果 d i m < 0 dim<0 dim<0 , 相当于dim+x.dim(),也就是说dim=-1相当于传入的是x.dim()-1

dim=0是沿着最粗数据粒度的方向进行操作,x.dim()-1是按照最细粒度的方向进行操作

代码示例:

1.keepdim
pytorch中torch.max()的用法_第1张图片
_的作用:
有时用作临时或无意义变量的名称。也表示python REPL中最近一个表达式的结果。
pytorch中torch.max()的用法_第2张图片

2.dim

x是二维数组:

dim=0表示第二维度固定,在第一维度上求,第一维度是行索引变化找最值,也就相当于取每一列的最大值。
pytorch中torch.max()的用法_第3张图片
dim=1表示第一维度固定,在第二维度上求,也就是列索引在变化找最值,相当于取每一行的最大值
pytorch中torch.max()的用法_第4张图片
x是三维数组:

pytorch中torch.max()的用法_第5张图片
pytorch中torch.max()的用法_第6张图片
dim=0表示固定第二和第三维,变化第一维求最值,可以把三维数组看成一个每个元素是个二维数组的一维数组,因此变化第一维相当于在每个二维数组之间对应位置求最值,因此得到的输出max的维度是2 * 4。
dim=1表示固定第一维和第三维,变化第二维求最值,第二维实际上是每个二维数组的第一维,也就是变化行索引而列索引不变,也就是相当于二维dim=0的情形,因此输出max的维度是3 * 4。
dim=2表示固定第一和第二维,变化第三维求最值,第三维实际上是每个二维数组的第二维,也就是变化列索引而行索引不变,也就是相当于二维dim=1的情形,因此输出max的维度是3*2。

pytorch中torch.max()的用法_第7张图片
对于三维的输入,dim=-1相当于dim=3-1=2,dim=-2相当于dim=3-2=1, 结果如上,可与前面dim=1和2的时候进行对比。

你可能感兴趣的:(pytorch,pytorch,深度学习)