tf的softmax交叉熵计算中的logits的含义

tf.nn.softmax_cross_entropy_with_logits函数是TensorFlow中常用的求交叉熵的函数。其中函数名中的“logits”是个什么意思呢?它时不时地困惑初学者,下面我们就讨论一下。

 
  1. tf.nn.softmax_cross_entropy_with_logits(

  2. _sentinel=None,

  3. labels=None,

  4. logits=None,

  5. dim=-1,

  6. name=None

  7. )

这个函数的功能就是计算labels和logits之间的交叉熵(cross entropy)。

第一个参数基本不用。此处不说明。
第二个参数label的含义就是一个分类标签,所不同的是,这个label是分类的概率,比如说[0.2,0.3,0.5],labels的每一行必须是一个概率分布。

现在来说明第三个参数logits,logit本身就是是一种函数,它把某个概率p从[0,1]映射到[-inf,+inf](即正负无穷区间)。这个函数的形式化描述为:logit=ln(p/(1-p))。
我们可以把logist理解为原生态的、未经缩放的,可视为一种未归一化的log 概率,如是[4, 1, -2]

于是,Softmax的工作则是,它把一个系列数从[-inf, +inf] 映射到[0,1],除此之外,它还把所有参与映射的值累计之和等于1,变成诸如[0.95, 0.05, 0]的概率向量。这样一来,经过Softmax加工的数据可以当做概率来用。

也就是说,logits是作为softmax的输入。经过softmax的加工,就变成“归一化”的概率(设为q),然后和labels代表的概率分布(设为q),于是,整个函数的功能就是前面的计算labels(概率分布p)和logits(概率分布q)之间的交叉熵

下面我们列举一个案例说明:

 
  1. #!/usr/bin/env python3

  2. # -*- coding: utf-8 -*-

  3. """

  4. Created on Thu May 10 08:32:59 2018

  5.  
  6. @author: yhilly

  7. """

  8.  
  9. import tensorflow as tf

  10.  
  11. labels = [[0.2,0.3,0.5],

  12. [0.1,0.6,0.3]]

  13. logits = [[4,1,-2],

  14. [0.1,1,3]]

  15.  
  16. logits_scaled = tf.nn.softmax(logits)

  17. result = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)

  18.  
  19. with tf.Session() as sess:

  20. print (sess.run(logits_scaled))

  21. print (sess.run(result))

运行结果:

 
  1. [[0.95033026 0.04731416 0.00235563]

  2. [0.04622407 0.11369288 0.84008306]]

  3. [3.9509459 1.6642545]

需要注意的是:


(1)如果labels的每一行是one-hot表示,也就是只有一个地方为1(或者说100%),其他地方为0(或者说0%),还可以使用tf.sparse_softmax_cross_entropy_with_logits()。之所以用100%和0%描述,就是让它看起来像一个概率分布。
(2)tf.nn.softmax_cross_entropy_with_logits()函数已经过时 (deprecated),它在TensorFlow未来的版本中将被去除。取而代之的是

tf.nn.softmax_cross_entropy_with_logits_v2()。

 (3)参数labels,logits必须有相同的形状 [batch_size, num_classes] 和相同的类型(float16, float32, float64)中的一种,否则交叉熵无法计算。

你可能感兴趣的:(机器学习模型)