TensorFlow notes: cross entropy loss

Cross entropy is one of the most common loss functions in model training. TensorFlow provides a function that computes it directly from logits, tf.nn.sigmoid_cross_entropy_with_logits, and Keras provides tf.keras.backend.binary_crossentropy. Note that the two functions expect different inputs.


First, let's look at TensorFlow's built-in sigmoid_cross_entropy_with_logits:

tf.nn.sigmoid_cross_entropy_with_logits(
    _sentinel=None,
    labels=None,
    logits=None,
    name=None
)

sigmoid_cross_entropy_with_logits() takes two arguments: logits, the raw output of the network's last layer, and labels, the ground-truth values. Internally it applies a sigmoid to the logits and then computes the cross entropy loss, as the following derivation shows:

Let x = logits and z = labels. Then:

Loss = - z * log(sigmoid(x)) - (1 - z) * log(1 - sigmoid(x))
     = - z * log(1 / (1 + exp(-x))) - (1 - z) * log(exp(-x) / (1 + exp(-x)))
     = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
     = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
     = (1 - z) * x + log(1 + exp(-x))
     = x - x * z + log(1 + exp(-x))

Using x + log(1 + exp(-x)) = log(1 + exp(x)), this is exactly the prob_error2 expression in the code below.
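For large |x|, exp(-x) can overflow in floating point. The TensorFlow documentation notes that the actual implementation therefore uses the equivalent, numerically stable form max(x, 0) - x * z + log(1 + exp(-|x|)). A minimal NumPy sketch of that stable form (stable_sigmoid_xent is an illustrative name, not a TensorFlow function):

import numpy as np

def stable_sigmoid_xent(logits, labels):
    # Same value as x - x*z + log(1 + exp(-x)), rewritten so the
    # exponent is always <= 0 and exp() can never overflow.
    x, z = logits, labels
    return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))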


tf.keras.backend.binary_crossentropy expects different input than sigmoid_cross_entropy_with_logits: because Keras wraps the computation at a higher level, it takes the network output after the sigmoid, i.e. probabilities. Internally, binary_crossentropy first converts the probabilities back into logits and then calls tf.nn.sigmoid_cross_entropy_with_logits to compute the cross entropy.
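In TF 1.x that conversion amounts to clipping the probabilities away from 0 and 1, then inverting the sigmoid. A rough sketch of the idea (binary_crossentropy_sketch and the epsilon value are illustrative, not the backend's actual code):

import tensorflow as tf

def binary_crossentropy_sketch(target, output, epsilon=1e-7):
    # Clip so the logit transform stays finite, then invert the
    # sigmoid: logit(p) = log(p / (1 - p)).
    output = tf.clip_by_value(output, epsilon, 1 - epsilon)
    logits = tf.log(output / (1 - output))
    return tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=logits)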


Below is a simple verification; note the different inputs passed to the two functions:

import numpy as np
import tensorflow as tf

def sigmoid(x):
    return 1.0 / (1 + np.exp(-x))

labels = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
logits = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
label = tf.convert_to_tensor(labels, np.float32)

# binary_crossentropy expects probabilities, so apply the sigmoid first
y_pred = sigmoid(logits)
y_preds = tf.convert_to_tensor(y_pred, np.float32)

# NumPy references: the textbook definition and the simplified closed form
prob_error1 = -labels * np.log(y_pred) - (1 - labels) * np.log(1 - y_pred)
prob_error2 = -logits * labels + np.log(1 + np.exp(logits))
print(prob_error1)
print(prob_error2)

print(".............")
with tf.Session() as sess:
    # Takes raw logits (the NumPy arrays are converted automatically)
    print(sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)))
    # Takes probabilities (the sigmoid outputs)
    print(sess.run(tf.keras.backend.binary_crossentropy(label, y_preds)))

'''
[[0.31326169 0.69314718 0.69314718]
 [0.69314718 0.31326169 0.69314718]
 [0.69314718 0.69314718 0.31326169]]
[[0.31326169 0.69314718 0.69314718]
 [0.69314718 0.31326169 0.69314718]
 [0.69314718 0.69314718 0.31326169]]
.............
[[0.31326169 0.69314718 0.69314718]
 [0.69314718 0.31326169 0.69314718]
 [0.69314718 0.69314718 0.31326169]]
[[0.3132617 0.6931472 0.6931472]
 [0.6931472 0.3132617 0.6931472]
 [0.6931472 0.6931472 0.3132617]]
'''
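On TensorFlow 2.x, where tf.Session is gone, the same comparison runs eagerly. A sketch assuming TF 2.x (tf.keras.losses.binary_crossentropy with from_logits=True would also accept the raw logits directly, but it averages over the last axis instead of returning elementwise losses):

import tensorflow as tf  # assumes TF 2.x

labels = tf.constant([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
logits = tf.constant([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])

# Both calls execute eagerly; no Session is needed.
print(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))
print(tf.keras.backend.binary_crossentropy(labels, tf.sigmoid(logits)))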
