tf.equal(tf.argmax(y,1),tf.argmax(y_,1))和tf.reduce_mean(tf.cast(correct_prediction,tf.float32))浅谈

在评估模型时候,我们首先预测类标,tf.argmax是一个很有用的函数,其返回值给定Tensor某一坐标轴上最高得分的索引值
例如:
tf.argmax(y,1)返回的是模型,每一输入数据最大可能的预测类标。
tf.argmax(y_,1)返回的是真实的类标。
最后我们用tf.equal函数检查预测类标与真实类标是否相同。

correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))

返回值correct_prediction 是一个布尔值链表。计算模型的精度还要计算链表的均值。
例如:
[True,False,True,True]可以用[1,0,1,1]表示,精度为0.75。
0.75 = 3/4

accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

tf.cast()函数的作用是执行 tensorflow 中张量数据类型转换
此处是将布尔型转换为float32

如有错误,请您指正,谢谢。

你可能感兴趣的:(机器学习和深度学习乱搞)