tf.keras.losses.SparseCategoricalCrossentropy用于计算多分类问题的交叉熵。标签应为一个整数,而不是one-hot编码形式。
tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=False,
reduction=losses_utils.ReductionV2.AUTO,
name='sparse_categorical_crossentropy'
)
from_logits:默认为False. 当from_logits=true时,会先对预测值进行Softmax概率化。
from tensorflow.keras.losses import SparseCategoricalCrossentropy
'''
每一个样本都有一个标签,表示它的真实类别。在深度学习中,通常是分批进行训练。
对于一个N分类问题,样本的标签只可能是0、1、2 ... N-1
则对于一个3分类问题,样本的标签只可能是0或1或2。
当batch_size为5时,这一个batch的标签的形状为[batch_size],即shape为[5]
'''
# 一个batch(batch_size=5)的标签
label = [1, 2, 0, 1, 0]
'''
对于一个3分类问题,当训练时batch_size为5,
则深度网络对每一个batch的预测值的形状为[batch_size, classes],即shape为[5, 3]
以深度网络对第一个样本的预测值[0.5, 0.8, -1.2]为例,
经过Softmax层后,得到[0.3949, 0.5330, 0.0721],表示深度网络认为第一个样本属于0,1,2这三类的概率分别是0.3949,0.5330, 0.0721
'''
predict = [[ 0.5, 0.8, -1.2],
[-0.2, 1.8, 0.5],
[ 0.3, 0.2, 0.7],
[ 0.6,-0.8, -0.4],
[-0.4, 0.2, 0.8]]
# 当from_logits=true时,会先对predict进行Softmax运算,就无需在网络的最后添加Softmax层
loss_func = SparseCategoricalCrossentropy(from_logits=True)
loss = loss_func(label, predict)
print(loss) # tf.Tensor(1.4376585, shape=(), dtype=float32)
tf.keras.losses.CategoricalCrossentropy用于计算多分类问题的交叉熵。标签应为one-hot编码形式。例如对于一个3分类问题,若属于第0类,标签应为[1, 0, 0]; 若属于第1类,标签应为[0, 1, 0]; 若属于第2类,标签应为[0, 0, 1]
from tensorflow.keras.losses import CategoricalCrossentropy
'''
对于一个3分类问题,若属于第0类,标签应为[1, 0, 0]; 若属于第1类,标签应为[0, 1, 0]; 若属于第2类,标签应为[0, 0, 1]
当batch_size为5时,这一个batch的标签的形状为[batch_size, classes],即shape为[5, 3]
'''
# 一个batch(batch_size=5)的标签
label = [[0, 1, 0],
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
[1, 0, 0]]
'''
对于一个3分类问题,当训练时batch_size为5,
则深度网络对每一个batch的预测值的形状为[batch_size, classes],即shape为[5, 3]
以深度网络对第一个样本的预测值[0.5, 0.8, -1.2]为例,
经过Softmax层后,得到[0.3949, 0.5330, 0.0721],表示深度网络认为第一个样本属于0,1,2这三类的概率分别是0.3949,0.5330, 0.0721
'''
predict = [[ 0.5, 0.8, -1.2],
[-0.2, 1.8, 0.5],
[ 0.3, 0.2, 0.7],
[ 0.6,-0.8, -0.4],
[-0.4, 0.2, 0.8]]
# 当from_logits=true时,会先对predict进行Softmax运算,就无需在网络的最后添加Softmax层
loss_func = CategoricalCrossentropy(from_logits=True)
loss = loss_func(label, predict)
print(loss) # tf.Tensor(1.4376585, shape=(), dtype=float32)