Tensorflow 的 argmax 接口可以返回一阶以上张量最大值所对应的分量索引。
比如 tensorflow.argmax ( [1,2,3,10,1]) 返回 10对应的索引3 。
对于超过一阶的张量,需要指定要搜索的是第几维的元素,这个维是以0开始的。比如
对于a = [[[10.0,25.0,3.0,4.0] , ] 这样一个张量,想找最里层的元素的最大值,可以tensorflow.argmax ( a , 2 ) 来获取。
下面是分别对三阶和一阶张量找最大值的例子。
import tensorflow as tf
def findMaxFromRank3() :
"""
找出一个三阶张量第三维(索引是2)的最大值
"""
a =tf.Variable( [[[10.0,25.0,3.0,4.0] , [10.0,251.0,35.0,4.0]] , [[100.0,25.0,3.0,4.0] , [10.0,250.0,3500.0,4.0]] ] )
b = tf.argmax( a , 2 )
se = tf.Session()
init = tf.global_variables_initializer()
se.run( init )
r = se.run( b )
ar = se.run( a )
se.close()
print( r )
print ( type (r ) )
print( ar[0][0][r[0][0]] , "," , ar[0][1][ r[0][1] ] )
print( ar[1][0][r[1][0]] , "," , ar[1][1][ r[1][1] ] )
def findMaxFromRank1() :
"""
找出一个一阶张量第一维(索引是0)的最大值
"""
a =tf.Variable( [10.0,250.0,3510.0,4.0 ] )
b = tf.argmax( a , 0 )
se = tf.Session()
init = tf.global_variables_initializer()
se.run( init )
r = se.run( b )
ar = se.run( a )
se.close()
print( r )
print ( type (r ) )
print( ar[r] )
findMaxFromRank3()
findMaxFromRank1()