sigmoid_cross_entropy_with_logits交叉熵损失简介及测试

def

这个函数的作用是计算经sigmoid 函数激活之后的交叉熵。

def sigmoid_cross_entropy_with_logits(_sentinel=None,  labels=None, logits=None,  name=None):

计算公式:

为了描述简洁,我们规定 x = logits(比如一张图),z = targets(分类结果)

z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))

对其化简

z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
      = (1 - z) * x + log(1 + exp(-x))
      = x - x * z + log(1 + exp(-x))

对于x<0时,也就是x是个很小的负数的时候,导致e^{-x}取到无穷大,那么再取log之后无穷大。
导致报错

RuntimeWarning: overflow encountered in exp

为了避免计算exp(-x)时溢出,我们使用以下这种形式表示:

 x - x * z + log(1 + exp(-x))
      = log(exp(x)) - x * z + log(1 + exp(-x))
      = - x * z + log(1 + exp(x))

但实际上这样对于x>0仍然是溢出,综合考虑:
使用

max(x,0)−x∗z+log(1+exp(−abs(x)))

这也是tensorflow中采用的公式。
测试如下代码。

code


import tensorflow as tf
import numpy as np

def sigmod(x):
    return 1.0/(1+np.exp(-x))

logits=np.array([[1.,-810.,20.],[11.,12.,14.],[12.,21.,23.]])
labels=np.array([[1.,0.,0.],[0.,1.,0.],[0.,0.,1.]])

y_predict=sigmod(logits)

loss_1=logits*(1-labels)+np.log(1+np.exp(-logits))
print('公式写的函数\n',loss_1)
print('------------------')
print('tensorflow中的函数\n')

with tf.Session() as sess:
    print(sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,logits=logits)))
print('------------------')
print('优化后的函数\n')

loss_2=np.maximum(logits,0)-logits*labels+np.log(1+np.exp(-np.abs(logits)))
print('优化后的函数\n',loss_2)

out:

公式写的函数
 [[  3.13261688e-01              inf   2.00000000e+01]
 [  1.10000167e+01   6.14419348e-06   1.40000008e+01]
 [  1.20000061e+01   2.10000000e+01   1.02618802e-10]]
------------------
tensorflow中的函数

[[  3.13261688e-01   0.00000000e+00   2.00000000e+01]
 [  1.10000167e+01   6.14419348e-06   1.40000008e+01]
 [  1.20000061e+01   2.10000000e+01   1.02618796e-10]]
------------------
优化后的函数

优化后的函数
 [[  3.13261688e-01   0.00000000e+00   2.00000000e+01]
 [  1.10000167e+01   6.14419348e-06   1.40000008e+01]
 [  1.20000061e+01   2.10000000e+01   1.02618802e-10]]

ref
https://blog.csdn.net/m0_37393514/article/details/81393819
https://www.cnblogs.com/cloud-ken/p/7435421.html

你可能感兴趣的:(tensorflow)