Tensorflow中CategoricalCrossentropy()和SparseCategoricalCrossentropy()的区别

今天在调试testset代码是报如下错:

ValueError: Can not squeeze dim[2], expected a dimension of 1, got 5 for '{{node Squeeze}} = Squeeze[T=DT_FLOAT, squeeze_dims=[-1]](strided_slice_1)' with input shapes: [32,10,5].

debug了一会才发现是loss函数的问题,这个问题之前也碰到过,也总结了,但是没有形成文字记录,现在重新总结一下。

Tensorflow中关于loss的关键一步为:

loss_object = tf.keras.losses.CategoricalCrossentropy()
# 或
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

包括在之后train_loss和test_loss中使用的metrics

train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
# 或
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

我们可以在多分类任务中出现两种loss的交叉熵,这两种有什么区别呢,其中使用的参数又是什么意思呢?

加Sparse和不加Sparse区别

Sparse 的含义是:真实的标签值y_true可以直接传入 int 类型的标签类别,即添加Sparse时 y 不需要one-hot,而不添加Sparse时需要进行独热编码(自己实现)。
也就是说如果我们没有在自己的labels做预处理实现one-hot,在之后的loss函数中需要使用SparseCategoricalCrossentropy(),如果我们已经自己实现了one-hot,那么就使用CategoricalCrossentropy(),添加Sparse与否取决于我们的数据集以及数据预处理。

这两种loss函数的参数含义

# CategoricalCrossentropy()
def __init__(self,
             from_logits: bool = False,
             label_smoothing: int = 0,
             reduction: Any = losses_utils.ReductionV2.AUTO,
             name: str = 'categorical_crossentropy') -> None
# SparseCategoricalCrossentropy()
def __init__(self,
             from_logits: bool = False,
             reduction: Any = losses_utils.ReductionV2.AUTO,
             name: str = 'sparse_categorical_crossentropy') -> None

from_logits:bool类型,默认False,y_pred 是否预期为 logits 张量。默认情况下,我们假设 y_pred 对概率分布进行编码,也就是我们需要提前对输出做softmax,如果我们没有在model中实现这里需要设置为True。
label_smoothing: 浮点数在 [0, 1] 中。当 > 0 时,标签值会被平滑,这意味着标签值的置信度会放松。例如,如果 0.1 ,则将 0.1 / num_classes 用于非目标标签,将 0.9 + 0.1 / num_classes 用于目标标签。
axis :计算交叉熵的轴(特征轴)。默认为 -1。
reduction :类型tf.keras.losses.Reduction适用于损失。默认值为AUTO.AUTO表示缩减选项将由使用上下文确定。对于几乎所有情况,这默认为SUM_OVER_BATCH_SIZE.当与tf.distribute.Strategy,在内置训练循环之外,例如tf.keras compile和fit, 使用AUTO或者SUM_OVER_BATCH_SIZE将引发错误。
name :实例的可选名称。默认为’categorical_crossentropy’。

你可能感兴趣的:(Tensorflow,tensorflow,python,深度学习)