tensorflow2.0使用自带的函数求精准率和召回率(解决Shapes (None, 10) and (None, 1) are incompatible)

本代码使用的是cifar10数据集,所以有十个类别
废话不多说,直接给代码吧

import tensorflow as tf
from tensorflow.keras import datasets, Sequential, layers,metrics
(x_train, y_train), _ = datasets.cifar10.load_data()

def procession(x, y):
    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.squeeze(y)
    y = tf.one_hot(y, depth=10)

    return x, y
model = Sequential([
    layers.Flatten(input_shape=(32, 32, 3)),
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(32, activation='relu'),
    layers.Dense(10, activation='softmax')
])
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(1000).map(procession).batch(128)
# model.compile(loss=tf.losses.binary_crossentropy, optimizer='adam', metrics=['accuracy'])
model.compile(loss=tf.losses.binary_crossentropy, optimizer='adam', metrics=[metrics.Recall()])
# model.compile(loss=tf.losses.binary_crossentropy, optimizer='adam', metrics=[metrics.Precision()])
model.fit(train_db, epochs=5)

出现Shapes (None, 10) and (None, 1) are incompatible的原因是:
x通过模型之后会得到一个shape为(None,10)的数据, 而y因为没有进行one_hot编码,y.shape=(None, 1),形状不同所以不能进行计算
没有对标签y_train进行one_hot编码,但是单单进行one_hot编码也是不够的,因为进行one_hot编码之后y_train.shape = (None, 1, 10)就会报错Shapes (None, 10) and (None, 1,10) are incompatible, 所以在对y_train进行处理时,通过tf.squeeze(y_train)是的y_train.shape = (None, 10),这样子就可以进行计算了。

你可能感兴趣的:(tensorflow,深度学习,人工智能,机器学习)