Tensorflow——tf.argmax()解析

简介

tf.argmax就是返回最大的那个数值所在的下标。

定义如下:
def argmax(self, axis=None, fill_value=None, out=None):

tf.argmax()的参数:比如,tf.argmax(array, 1)和tf.argmax(array, 0)有啥区别呢?

例子

test = np.array([[1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]])
np.argmax(test, 0)   #输出:array([3, 3, 1]
np.argmax(test, 1)   #输出:array([2, 2, 0, 0]

tf.argmax(array, 1)

等于1的时候,比较范围缩小了,只会比较每个数组内的数的大小,结果也会根据有几个数组,产生几个结果。

test[0] = array([1, 2, 3])  #2
test[1] = array([2, 3, 4])  #2
test[2] = array([5, 4, 3])  #0
test[3] = array([8, 7, 2])  #0

tf.argmax(array, 0)

你就这么想,0是最大的范围,所有的数组都要进行比较,只是比较的是这些数组相同位置上的数:

test[0] = array([1, 2, 3])
test[1] = array([2, 3, 4])
test[2] = array([5, 4, 3])
test[3] = array([8, 7, 2])
# output   :    [3, 3, 1]  

你可能感兴趣的:(Tensorflow——tf.argmax()解析)