sigmoid交叉熵和softmax交叉熵的区别

1、tf.nn.softmax_cross_entropy_with_logits原理

 要求logits与label形状一致, 是先对logits做softmax之后, 再与label做交叉熵运算

loss的输出形状:形状为[batch_size, 1]

import tensorflow as tf

#三个分类,交叉熵通过了sigma求和,因此一个样本对应一个交叉熵,batch_size个样本对应batch_size个交叉熵

#此种情况为多分类问题,但是每个样本只属于一个类别,softmax交叉熵算出来的是一个值

labels = [[0.2,    0.3,    0.5],
               [0.1,   0.6,    0.3]]
logits = [[2,     0.5,  1],
             [0.1,   1,    3]]
logits_scaled = tf.nn.softmax(logits)

result1 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
result2 = -tf.reduce_sum(labels*tf.log(logits_scaled),1)
result3 = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits_scaled)

with tf.Session() as sess:
    print(sess.run(result1))
    print(sess.run(result2))
    print(sess.run(result3))

> [1.4143689 1.6642545]
  [1.4143689 1.6642545]
  [1.1718578 1.1757141]


2、tf.nn.sigmoid_cross_entropy_with

      该函数中labels和logits具有相同的形状,它对于输入的logits先通过sigmoid函数计算,再计算它们的交叉熵,但是它对交叉熵的计算方式进行了优化,使得结果不至于溢出,它适用于每个类别相互独立但互不排斥的情况:例如一幅图可以同时包含一条狗和一只大象,output不是一个数,而是一个batch中每个样本的loss,所以一般配合tf.reduce_mean(loss)使用


loss输出:shape:[batch_size, num_classes], 与label和logits具有相同的形状。

#-*-coding:UTF-8
import tensorflow as tf
import numpy as np
def sigmoid(x):
    return 1.0/(1+np.exp(-x))

#5个样本三分类问题,且一个样本可以同时拥有多类,一个样本会在每个类别上有一个交叉熵
y = np.array([[1,  0,  0], [0,  1,  0], [0,  0,  1], [1,  1,  0], [0,  1,  0]])
logits = np.array([[12,  3,  2], [3,  10,  1], [1,  2,  5], [4,  6.5,  1.2], [3,  6,  1]])
y_pred = sigmoid(logits)
E1 = -y * np.log(y_pred)- (1-y) * np.log(1-y_pred)
#按计算公式计算的结果
print(E1)
sess = tf.Session()
y = np.array(y).astype(np.float64) # labels是float64的数据类型
E2 = sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=y,logits=logits))
print(E2)

E1和E2的结果一样,都为:

[[6.14419348e-06     3.04858735e+00    2.12692801e+00]
 [3.04858735e+00    4.53988992e-05     1.31326169e+00]
 [1.31326169e+00    2.12692801e+00     6.71534849e-03]
 [1.81499279e-02    1.50231016e-03      1.46328247e+00]
 [3.04858735e+00    2.47568514e-03     1.31326169e+00]]

你可能感兴趣的:(机器学习,损失函数)