tf.one_hot 在求交叉熵是会用到,要对目标值(标签值)进行one_hot编码。相当于将多个数值联合放在一起作为多个相同类型的向量,可用于表示各自的概率分布。比如在求交叉熵的时候:
交叉熵公式:
yi我们知道是softmax后的结果,也就是某个样本是所有类别的每一个类别的概率,yi’ 是真实的结果,也是一个概率,那应该是多少呢?刚刚说了one_hot相当于将多个数值联合放在一起作为多个相同类型的向量,可用于表示各自的概率分布;因此我计算交叉熵是这样的计算的:
如上图,一个手写体识别的例子,手写体一个有10个类别(0-9),也就是0,1,2,3,45,6,7,8,9。若识别某个样本,这个样本的真实值是1,那么one_hot编码后是[0 1 0 0 0 0 0 0 0 0], softmax后的结果是[0.0 0.7 0.1 …0.02],交叉熵是0log(0.0) + 1log(0.7) + 0*log(0.1)…
API参数说明:
tf.one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)
Returns a one-hot tensor.
indices表示输入的多个数值,通常是矩阵形式;depth表示输出的尺寸
由于one-hot类型数据长度为depth位,其中只用一位数字表示原输入数据,这里的on_value就是这个数字,默认值为1,one-hot数据的其他位用off_value表示,默认值为0。
tf.one_hot()函数规定输入的元素indices从0开始,最大的元素值不能超过(depth - 1),因此能够表示depth个单位的输入。若输入的元素值超出范围,输出的编码均为 [0, 0 … 0, 0]。
indices = 0 对应的输出是[1, 0 … 0, 0], indices = 1 对应的输出是[0, 1 … 0, 0], 依次类推,最大可能值的输出是[0, 0 … 0, 1]。
例子:
1、当depth =0
import tensorflow as tf
lable = [0, 1, 2, 1]
depth = 0
lable_ont_hot = tf.one_hot(lable, depth)
print(lable_ont_hot)
with tf.Session() as sess:
print(sess.run(lable_ont_hot))
结果:
C:\Users\FCX-PC\Envs\tensorflow\Scripts\python.exe J:/测试.py
Tensor("one_hot:0", shape=(4, 0), dtype=float32)
[]
Process finished with exit code 0
2、当depth =1时:
结果:
C:\Users\FCX-PC\Envs\tensorflow\Scripts\python.exe J:/测试.py
Tensor("one_hot:0", shape=(4, 1), dtype=float32)
[[1.]
[0.]
[0.]
[0.]]
Process finished with exit code 0
3、当depth =2时
结果:
C:\Users\FCX-PC\Envs\tensorflow\Scripts\python.exe J:/测试.py
Tensor("one_hot:0", shape=(4, 2), dtype=float32)
[[1. 0.]
[0. 1.]
[0. 0.]
[0. 1.]]
Process finished with exit code 0
4、当depth =3时
结果:
C:\Users\FCX-PC\Envs\tensorflow\Scripts\python.exe J:/测试.py
Tensor("one_hot:0", shape=(4, 3), dtype=float32)
[[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]
[0. 1. 0.]]
Process finished with exit code 0
5、当depth =4时
结果
C:\Users\FCX-PC\Envs\tensorflow\Scripts\python.exe J:/测试.py
Tensor("one_hot:0", shape=(4, 4), dtype=float32)
[[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]
[0. 1. 0. 0.]]
Process finished with exit code 0
通过上述一系列的depth导致one_hot不同的结果,我们也可以看到,depth的大小就是编码后数组的列数。
可以看到原来的 lable = [0, 1, 2, 1], 当depth = 3是变成
[[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]
[0. 1. 0.]]
大小是(4,3),怎么变成的呢, 其实就是1的不同位置代表原来的值得大小,只是这个位置是从0开始的;比如,原来label第一个元素是0,表示为 [1 0 0],因为1在第0位,同理,元素2,在第一位也就是[0 0 1]
所以,当depth不大时,(depth-1)小于label中最大元素 2是比如depth =2时,结果是
[[1. 0.]
[0. 1.]
[0. 0.]
[0. 1.]]
也就是说,表示不不出来2,因为2是【0 0 1】必须3个位置才可以。