解决accuracy_score报错Classification metrics can‘t handle a mix of continuous and multiclass targets

问题原因:label和predicted类型不一致

如果label是one-hot形式,需要在inputs_labels后面加.argmax(axis=1)进行反one-hot编码

acc = accuracy_score(inputs_labels.argmax(axis=1),predicted.numpy().argmax(axis=1))

此外,如果求交叉熵的时候,label是one-hot形式,使用:

loss_fn = tf.keras.losses.CategoricalCrossentropy()

如果不是one-hot形式,使用:

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

你可能感兴趣的:(机器学习)