Pytorch拾遗(2).max()和.min()方法的详解

min和max使用方法一样,主要以max为主。

#A.min(0): 返回A每一列最小值组成的一维数组;
#A.min(1):返回A每一行最小值组成的一维数组;
#A.max(0):返回A每一列最大值组成的一维数组;
#A.max(1):返回A每一行最大值组成的一维数组;

在pytorch写的代码 特别是强化学习DQN中需要从记忆池中选择transition时,用到了下面的代码

q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)

其实
A.max(0)和max(0,A)是一样的用法

针对上面的q_target 代码,这里我们需要关注两个参数:
第一个是max(1)这个括号里面的1.上面也说清楚了,0表示每一列。1表示每一行。
第二个是max(1)[0],后面这个[]里面的只有两个选项,0和1.其中,0表示选择提取出来的最大值,而1则表示提取的最大值对应的index。

from numpy import array  # 从numpy中引入array,为创建矩阵做准备
import numpy as np
import torch

A1 = array([[1, 2, 3],  # 创建一个4行3列的矩阵
          [4, 5, 6],
          [7, 8, 9],
          [10, 11, 12]])

A = array([[1, 2, 3]])

#B = A.max(1)  # 返回A每一行最小值组成的一维数组;
#print(B)  # 结果 :[1 2 3]

B1=torch.from_numpy(A1)
print('B1',B1)  
'''
B1 tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]], dtype=torch.int32)
'''
B0 = torch.max(B1, 1)[0].data.numpy()  #行
print('B0',B0)  # 
'''
B0 [ 3  6  9 12]
'''
B2 = torch.max(B1, 1)[1].data.numpy()  #行
print('B2',B2)  # 
'''
B2 [2 2 2 2]    最大值对应行的索引
'''

你可能感兴趣的:(Pytorch,pytorch)