这个函数的作用是计算经sigmoid 函数激活之后的交叉熵。
def sigmoid_cross_entropy_with_logits(_sentinel=None, labels=None, logits=None, name=None):
为了描述简洁,我们规定 x = logits(比如一张图),z = targets(分类结果)
z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
对其化简
z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + log(1 + exp(-x))
= x - x * z + log(1 + exp(-x))
对于x<0时,也就是x是个很小的负数的时候,导致e^{-x}取到无穷大,那么再取log之后无穷大。
导致报错
RuntimeWarning: overflow encountered in exp
为了避免计算exp(-x)时溢出,我们使用以下这种形式表示:
x - x * z + log(1 + exp(-x))
= log(exp(x)) - x * z + log(1 + exp(-x))
= - x * z + log(1 + exp(x))
但实际上这样对于x>0仍然是溢出,综合考虑:
使用
max(x,0)−x∗z+log(1+exp(−abs(x)))
这也是tensorflow中采用的公式。
测试如下代码。
import tensorflow as tf
import numpy as np
def sigmod(x):
return 1.0/(1+np.exp(-x))
logits=np.array([[1.,-810.,20.],[11.,12.,14.],[12.,21.,23.]])
labels=np.array([[1.,0.,0.],[0.,1.,0.],[0.,0.,1.]])
y_predict=sigmod(logits)
loss_1=logits*(1-labels)+np.log(1+np.exp(-logits))
print('公式写的函数\n',loss_1)
print('------------------')
print('tensorflow中的函数\n')
with tf.Session() as sess:
print(sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,logits=logits)))
print('------------------')
print('优化后的函数\n')
loss_2=np.maximum(logits,0)-logits*labels+np.log(1+np.exp(-np.abs(logits)))
print('优化后的函数\n',loss_2)
out:
公式写的函数
[[ 3.13261688e-01 inf 2.00000000e+01]
[ 1.10000167e+01 6.14419348e-06 1.40000008e+01]
[ 1.20000061e+01 2.10000000e+01 1.02618802e-10]]
------------------
tensorflow中的函数
[[ 3.13261688e-01 0.00000000e+00 2.00000000e+01]
[ 1.10000167e+01 6.14419348e-06 1.40000008e+01]
[ 1.20000061e+01 2.10000000e+01 1.02618796e-10]]
------------------
优化后的函数
优化后的函数
[[ 3.13261688e-01 0.00000000e+00 2.00000000e+01]
[ 1.10000167e+01 6.14419348e-06 1.40000008e+01]
[ 1.20000061e+01 2.10000000e+01 1.02618802e-10]]
ref
https://blog.csdn.net/m0_37393514/article/details/81393819
https://www.cnblogs.com/cloud-ken/p/7435421.html