tensorflow 的argmax

Tensorflow 的 argmax 接口可以返回一阶以上张量最大值所对应的分量索引。

比如 tensorflow.argmax ( [1,2,3,10,1])  返回 10对应的索引3 。

对于超过一阶的张量,需要指定要搜索的是第几维的元素,这个维是以0开始的。比如

对于a = [[[10.0,25.0,3.0,4.0] , ]  这样一个张量,想找最里层的元素的最大值,可以tensorflow.argmax ( a , 2 ) 来获取。

下面是分别对三阶和一阶张量找最大值的例子。

import tensorflow as tf


def findMaxFromRank3() :
    """
    找出一个三阶张量第三维(索引是2)的最大值
    """
    a =tf.Variable( [[[10.0,25.0,3.0,4.0] , [10.0,251.0,35.0,4.0]] , [[100.0,25.0,3.0,4.0] , [10.0,250.0,3500.0,4.0]] ] )
    b = tf.argmax(  a  , 2 )
    se = tf.Session()
    init = tf.global_variables_initializer()
    se.run( init ) 
    r = se.run( b )
    ar = se.run( a )
    se.close()
    print( r )
    print ( type (r )  )
    print( ar[0][0][r[0][0]] , "," , ar[0][1][ r[0][1] ] )
    print( ar[1][0][r[1][0]] , "," , ar[1][1][ r[1][1] ] )


def findMaxFromRank1() :
    """
    找出一个一阶张量第一维(索引是0)的最大值
    """
    a =tf.Variable( [10.0,250.0,3510.0,4.0 ] )
    b = tf.argmax(  a  , 0 )
    se = tf.Session()
    init = tf.global_variables_initializer()
    se.run( init ) 
    r = se.run( b )
    ar = se.run( a )
    se.close()
    print( r )
    print ( type (r )  )
    print( ar[r] )
    


findMaxFromRank3()
findMaxFromRank1()  

你可能感兴趣的:(python,tensorflow)