TensorFlow deep learning loss functions: tf.nn.softmax_cross_entropy_with_logits

While studying deep learning, I ran into something confusing: different training setups use different loss functions.

Some training setups (scenario A) apply softmax first and then compute the cross entropy:

# y: predicted probabilities (softmax of logits); y_: labels
y=tf.nn.softmax(logits)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y), axis=1))
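Row by row, scenario A's two lines amount to the following. This is a minimal pure-Python sketch for illustration only, assuming label rows that are one-hot or otherwise sum to 1; the function names are hypothetical, not TensorFlow APIs:

```python
import math

def softmax(logits):
    """Softmax of one row of logits; the max is subtracted first so exp() cannot overflow."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def two_step_cross_entropy(batch_logits, batch_labels):
    """Scenario A: y = softmax(logits), then -sum(y_ * log(y)) per row, then the batch mean."""
    per_example = []
    for z, t in zip(batch_logits, batch_labels):
        y = softmax(z)
        per_example.append(-sum(ti * math.log(yi) for ti, yi in zip(t, y)))
    return sum(per_example) / len(per_example)

print(two_step_cross_entropy([[1.0, 2.0, 3.0]], [[0.0, 0.0, 1.0]]))  # about 0.4076
```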

Other setups (scenario B) instead use:

# logits: the network's raw (pre-softmax) outputs; y_: labels
cross_entropy2=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))
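For comparison, here is a minimal pure-Python sketch of a loss that, like scenario B, takes the raw logits rather than probabilities. It assumes each label row sums to 1 and computes log-softmax with the usual max-subtraction trick; `fused_cross_entropy` is a hypothetical name, and this is an illustration of the idea, not TensorFlow's actual kernel:

```python
import math

def fused_cross_entropy(labels, logits):
    """Hypothetical fused softmax + cross entropy for one example.

    Computes -sum_j t_j * log_softmax(z)_j, where
    log_softmax(z)_j = z_j - log(sum_k exp(z_k)), evaluated with the max trick.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(t * (z - log_sum_exp) for t, z in zip(labels, logits))

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [1.0, 2.0, 3.0]
labels = [0.0, 0.0, 1.0]
print(fused_cross_entropy(labels, logits))           # about 0.4076
# Pitfall: passing probabilities instead of raw logits applies softmax twice.
print(fused_cross_entropy(labels, softmax(logits)))  # a different, wrong value
```

The last line illustrates why the logits argument must be the raw network outputs: TensorFlow's documentation warns against passing the output of softmax to this op, since softmax would then be applied twice.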

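This left me quite confused: how do these two loss functions differ, and when should each one be used? After going through a number of references, I learned that tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_) first applies softmax to the raw outputs logits and then computes the cross entropy between the softmax result and the labels y_. In other words, scenario A and scenario B compute the same loss; B simply fuses A's two steps into a single op.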
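Why the two agree can be seen in a short derivation for a single example. Here $z$ is one row of `logits`, $y = \operatorname{softmax}(z)$ as in scenario A, and $t$ is the matching row of the labels `y_`, assumed to satisfy $\sum_j t_j = 1$ (true for one-hot labels):

```latex
\begin{aligned}
y_j &= \operatorname{softmax}(z)_j = \frac{e^{z_j}}{\sum_k e^{z_k}} \\
H(t, y) &= -\sum_j t_j \log y_j
         = -\sum_j t_j \Big( z_j - \log \sum_k e^{z_k} \Big) \\
        &= \log \sum_k e^{z_k} - \sum_j t_j z_j
\end{aligned}
```

The last line needs only the logits, never the probabilities themselves, which is what allows a fused op to avoid taking the log of a probability that has underflowed to 0.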

To check this conclusion, I built the following demo, adapted from someone else's experiment:

import tensorflow as tf  # written against the TensorFlow 1.x API (tf.log, tf.Session)

# Raw outputs (logits) of our network: 5 examples, 3 classes
logits=tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0],[7.0,4.0,5.0],[7.0,88.0,5.0]])
# Step 1: apply softmax to get predicted probabilities
y=tf.nn.softmax(logits)
# One-hot true labels
y_=tf.constant([[0.0,0.0,1.0],[0.0,0.0,1.0],[0.0,0.0,1.0],[1.0,0.0,0.0],[1.0,0.0,0.0]])
# Step 2: cross entropy, summed over classes (axis=1) and averaged over the batch
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y), axis=1))
# The same loss in one fused op; it returns one loss per example, so don't forget tf.reduce_mean()
cross_entropy2=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))

with tf.Session() as sess:
    # Evaluate all three tensors in a single run call
    softmax, c_e, c_e2 = sess.run([y, cross_entropy, cross_entropy2])
    print("step1:softmax result=")
    print(softmax)
    print("step2:cross_entropy result=")
    print(c_e)
    print("Function(softmax_cross_entropy_with_logits) result=")
    print(c_e2)

Output:

step1:softmax result=
[[9.0030573e-02 2.4472848e-01 6.6524094e-01]
 [9.0030573e-02 2.4472848e-01 6.6524094e-01]
 [9.0030573e-02 2.4472848e-01 6.6524094e-01]
 [8.4379470e-01 4.2010065e-02 1.1419519e-01]
 [6.6396770e-36 1.0000000e+00 8.9858262e-37]]
step2:cross_entropy result=
16.478533
Function(softmax_cross_entropy_with_logits) result=
16.478533
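The same numbers can be cross-checked without TensorFlow. The sketch below recomputes both versions of the loss for the demo's logits and labels in plain Python; it uses double precision, whereas the TensorFlow demo used float32, so the last digits may differ slightly:

```python
import math

# The same logits and one-hot labels as in the TensorFlow demo above.
LOGITS = [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0],
          [7.0, 4.0, 5.0], [7.0, 88.0, 5.0]]
LABELS = [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 1.0],
          [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def two_step_loss(z, t):
    # Scenario A: softmax first, then -sum(t * log(p)).
    return -sum(ti * math.log(pi) for ti, pi in zip(t, softmax(z)))

def fused_loss(z, t):
    # Scenario B written in log-sum-exp form; valid when sum(t) == 1.
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return lse - sum(ti * zi for ti, zi in zip(t, z))

per_a = [two_step_loss(z, t) for z, t in zip(LOGITS, LABELS)]
per_b = [fused_loss(z, t) for z, t in zip(LOGITS, LABELS)]
mean_a = sum(per_a) / len(per_a)
mean_b = sum(per_b) / len(per_b)
print(mean_a, mean_b)  # both about 16.4785
```

It also shows where the 16.48 comes from: each of the first four rows contributes less than 0.41, while the fifth row, whose correct class gets a probability of about 6.6e-36, contributes about 81.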

Both approaches produce the same value, 16.478533, which confirms the conclusion above.
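The two losses are equal mathematically but not always numerically. In the fifth row of the demo, the correct class already gets a probability of only about 6.6e-36; with a somewhat larger gap between logits (a little over 100 in float32), softmax can underflow to exactly 0, tf.log then returns -inf, and the two-step loss becomes NaN or infinite, while a log-sum-exp formulation stays finite. This is a common reason to prefer the fused op. The sketch below reproduces the effect in double precision, where the gap must be larger; `log_like_tf` is a hypothetical helper that mimics tf.log returning -inf at 0 (Python's `math.log(0.0)` raises instead):

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def log_like_tf(p):
    # Hypothetical helper: return -inf at 0 instead of raising ValueError.
    return math.log(p) if p > 0.0 else float("-inf")

def two_step_loss(z, t):
    # Scenario A: probabilities first, then the log.
    return -sum(ti * log_like_tf(pi) for ti, pi in zip(t, softmax(z)))

def fused_loss(z, t):
    # Log-sum-exp form: works on the logits directly, never takes log(0).
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return lse - sum(ti * zi for ti, zi in zip(t, z))

# Like the demo's fifth row, but with a gap large enough that exp(-1000)
# underflows to 0.0 even in double precision.
z = [0.0, 1000.0, 0.0]
t = [1.0, 0.0, 0.0]
print(two_step_loss(z, t))  # nan (0 * -inf appears in the sum)
print(fused_loss(z, t))     # 1000.0
```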

References:

  1. https://blog.csdn.net/yhily2008/article/details/80262321
  2. https://blog.csdn.net/qq_35203425/article/details/79773459
  3. https://www.jianshu.com/p/fa91da9ec643
  4. https://blog.csdn.net/mao_xiao_feng/article/details/53382790
  5. https://www.imooc.com/article/23674
