np.argmax(a, axis=None, out=None)
tf.argmax(input, axis=None, name=None, dimension=None, output_type=dtypes.int64):
np.argmax()与tf.agrmax()函数用法类似,用于寻找每一行或者每一列中的最大值的索引值,axis的值代表行或列,分别表示为axis=0 按列寻找、axis=1 按行寻找;
eg:
>>> test = np.array([
... [2, 5, 6],
... [8, 12, 1],
... [3, 10, 2],
... [4, 6, 9]])
>>> np.argmax(test,0)
array([1, 1, 3], dtype=int64)
>>> np.argmax(test,1)
array([2, 1, 1, 2], dtype=int64)
当只输入input值时,axis值默认为None,这时在使用argmax(input)时,np.argmax()函数会遍历每一行寻找最大值的索引。
eg:
>>> test = np.array([
... [2, 5, 6],
... [8, 12, 1],
... [3, 10, 2],
... [4, 6, 9]])
>>> np.argmax(test)
4
而tf.argmax()会默认按照axis=0,即按照列去寻找最大索引值
eg:
>>> a = tf.argmax(test)
>>> tf.Session().run(a)
array([1, 1, 3], dtype=int64)
而当input中只包含一行数组时,axis只能取0,即按列寻找最大索引值。若axis=1,tensorflow和numpy均会报错
eg:
>>> test2 = np.array([1,3,5,2])
>>> np.argmax(test2,1)
Traceback (most recent call last):
File "", line 1, in
File "<__array_function__ internals>", line 6, in argmax
File "D:\Develop\anaconda3\envs\TF115_py36\lib\site-packages\numpy\core\fromnumeric.py", line 1188, in argmax
return _wrapfunc(a, 'argmax', axis=axis, out=out)
File "D:\Develop\anaconda3\envs\TF115_py36\lib\site-packages\numpy\core\fromnumeric.py", line 58, in _wrapfunc
return bound(*args, **kwds)
numpy.AxisError: axis 1 is out of bounds for array of dimension 1