PyTorch:torch.max、min、argmax、argmin

目录

1、torch.max

2、torch.argmax

3、torch.min

4、torch.argmin


1、torch.max

函数定义:

torch.max(input, dim, max=None, max_indices=None, keepdim=False) -> (Tensor, LongTensor)

作用:找出给定tensor的指定维度dim上的上的最大值,并返回最大值在该维度上的位置索引

应用举例:

例1——返回相应维度上的最大值

import torch
a = torch.randint(2, 10,(6,4))      # 创建shape为6*4,值为[2,10]的随机整数的tensor
b, max_index = torch.max(a, dim=1)  # 找出a的第1维度(列)上的最大值,返回结果和最大值在相应维度的序号
print('a:', a)
print('b:', b)
print('max_index:', max_index)

'''   输出结果   '''
a: tensor([[9, 6, 6, 5],
           [5, 7, 5, 8],
           [2, 2, 7, 9],
           [8, 9, 3, 5],
           [8, 7, 3, 3],
           [9, 6, 9, 3]])
b: tensor([9, 8, 9, 9, 8, 9])
max_index: tensor([0, 3, 3, 1, 0, 2])

例2——如果max的参数只有一个tensor,则返回该tensor里所有值中的最大值。

import torch
a = torch.randint(2, 10,(6,4))      # 创建shape为6*4,值为[2,10]的随机整数的tensor
b = torch.max(a)  # 找出a的所有元素中的最大值,返回结果
print('a:', a)
print('b:', b)

'''   输出结果   '''
a: tensor([[8, 2, 2, 4],
           [7, 4, 3, 4],
           [4, 4, 3, 4],
           [9, 7, 7, 2],
           [5, 4, 7, 9],
           [4, 5, 7, 5]])
b: tensor(9)

 例3——如果max的参数是两个相同shape的tensor,则返回两tensor对应的最大值的新tensor。

import torch
a = torch.randint(2, 10,(6,4))      # 创建shape为6*4,值为[2,10]的随机整数的tensor
b = torch.randint(2, 10,(6,4))  # 找出a的第1维度(列)上的最大值,返回结果和最大值在相应维度的序号
c = torch.max(a, b)  # 找出a的第1维度(列)上的最大值,返回结果和最大值在相应维度的序号
print('a:', a)
print('b:', b)
print('c:', c)

'''   运行结果   '''
a: tensor([[4, 6, 3, 4],
           [2, 2, 8, 3],
           [6, 2, 6, 8],
           [3, 9, 8, 5],
           [4, 7, 4, 4],
           [9, 5, 8, 3]])
b: tensor([[8, 2, 3, 9],
           [6, 7, 4, 6],
           [8, 9, 3, 6],
           [8, 4, 7, 5],
           [9, 3, 7, 6],
           [4, 7, 9, 6]])
c: tensor([[8, 6, 3, 9],
           [6, 7, 8, 6],
           [8, 9, 6, 8],
           [8, 9, 8, 5],
           [9, 7, 7, 6],
           [9, 7, 9, 6]])

 例4——keepdim=True, 返回的值和位置索引保持原有的维度数。

import torch

a = torch.randint(2, 10,(6,4))      # 创建shape为6*4,值为[2,10]的随机整数的tensor
b, max_index = torch.max(a, dim=1, keepdim=True)  # 找出a的第1维度(列)上的最大值,返回结果和最大值在相应维度的序号
print('a:', a)
print('b:', b)
print('max_index:', max_index)

#=============运行结果===============#
a: tensor([[6, 7, 6, 5],
        [7, 6, 2, 3],
        [2, 3, 7, 3],
        [4, 7, 4, 8],
        [5, 7, 7, 6],
        [5, 4, 5, 6]])
b: tensor([[7],
        [7],
        [7],
        [8],
        [7],
        [6]])
max_index: tensor([[1],
        [0],
        [2],
        [3],
        [2],
        [3]])

 

2、torch.argmax

定义:

torch.argmax(input, dim, keepdim=False) → LongTensor

作用:返回输入张量中指定维度的最大值的索引。

举例说明:

例1——指定维度:返回相应维度最大值的索引

import torch
a = torch.randint(9,(3, 3))
max_index = torch.argmax(a, dim=0)
print('a:\n', a)
print('max_index:\n', max_index)

'''   运行结果   '''
a:
 tensor([[1, 1, 5],
         [2, 8, 1],
         [3, 7, 3]])
max_index:
 tensor([2, 1, 0])

例2——不指定维度,返回整体上最大值的序号

import torch
a = torch.randint(9,(3, 3))
max_index = torch.argmax(a)
print('a:\n', a)
print('max_index:\n', max_index)

'''   运行结果   '''
a:
 tensor([[5, 2, 2],
         [7, 2, 0],
         [8, 0, 6]])
max_index:
 tensor(6)   # 注:tensor在内存中是顺序存储,所以8所在的序号是6

3、torch.min

用法同max。

4、torch.argmin

用法同argmax。

你可能感兴趣的:(Pytorch,pytorch)