tensor 最大值的下标的求法

import tensorflow as tf
import numpy as np

a=tf.constant([1,0,-1,12,5],dtype=tf.float32)
# b=tf.argmax(input(a),axis=1) 出错
flag=0

with tf.Session() as sess:
    print(sess.run(a))
    b = a.eval()

    for i in range(4):
        c=b[i]
        if c>=b[i]:
            flag=i
    print(flag)


    b = tf.constant([1, 0, -1, 12, 5], dtype=tf.float32)
    b=b.eval().tolist() #b转为list,调用函数index(max())求最大值下标
    dd = b.index(max(b))
    print(dd)

ar=np.array([[1,2,3,4],[5,6,7,8]])
print(ar)
print(np.argmax(ar,axis=1))

 

你可能感兴趣的:(python)