1 TensorFlow报错
报错信息:
2 报错原因
字面原因:
这个问题是由于输出层的类别数和训练数据shape不同导致。
底层原因:
Step1 : 代码中,我通过ImageDataGenerator
函数获取的图像生成器,会自动将图像label转为one-hot编码格式
train_image_generator = ImageDataGenerator(rescale=1./255, horizontal_flip=True)
val_image_generator = ImageDataGenerator(rescale=1./255)
train_data_gen = train_image_generator.flow_from_directory(directory = train_dir,
batch_size = batch_size,
shuffle=True,
target_size = (im_height, im_width),
class_mode=’categorical’)
val_data_gen = val_image_generator.flow_from_directory(directory = val_dir,
batch_size = batch_size,
shuffle=False,
target_size = (im_height, im_width),
class_mode=’categorical’)
train_imgs_batch, train_labels_batch = next(train_data_gen)
print(train_labels_batch[:5])
输出:
[[0. 0. 0. 0. 1.]
[1. 0. 0. 0. 0.]
[0. 0. 0. 0. 1.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]]
Step2 : 而在构造模型的loss函数和accuracy计算方法时,分别采用了SparseCategoricalCrossentropy
和SparseCategoricalAccuracy
。
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name=’train_accuracy’)
而在TensorFlow官方文档有关tf.keras.losses.CategoricalCrossentropy函数中有说明:
accuracy也有类似说明:
输入的label经过了one hot编码,但是loss和accuracy却调错,使用了不采用one-hot编码的SparseCategoricalCrossentropy和SparseCategoricalAccuracy。
3 解决方法
直接改成对应的loss函数CategoricalCrossentropy和CategoricalAccuracy即可。