sample_weight argument is not supported when using dataset as input 解决方案

当使用Keras 调用fit进行训练模型时,有时需要传入sample_weight参数。对于传入fit的为普通的输入输出序列数据时没有问题。但是当传入的数据时Datasent时就会报错。解决方案如下:

https://keras.io/api/models/model_training_apis/

直接把sample_weight 作为Dataset的第三维!

dataset = tf.data.Dataset.from_tensor_slices((input, output, weights))
dataset = dataset.batch(32)
model.fit(dataset, epochs=10,steps_per_epoch=10) 

你可能感兴趣的:(【——机器学习相关——】)