在使用 Keras 中的 load_model
函数重新载入模型的时,会出现如下的报错
Traceback (most recent call last):
File "test_unet.py", line 79, in <module>
model = load_model(weight_path)
File "/usr/local/lib/python2.7/dist-packages/keras/models.py", line 274, in load_model
sample_weight_mode=sample_weight_mode)
File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py", line 636, in compile
loss_function = losses.get(loss)
File "/usr/local/lib/python2.7/dist-packages/keras/losses.py", line 122, in get
return deserialize(identifier)
File "/usr/local/lib/python2.7/dist-packages/keras/losses.py", line 114, in deserialize
printable_module_name='loss function')
File "/usr/local/lib/python2.7/dist-packages/keras/utils/generic_utils.py", line 164, in deserialize_keras_object
':' + function_name)
ValueError: Unknown loss function:dice_coef_loss
可以看到函数发生错误的地方可以追溯到 load_model
位置,分析提醒可以发现,是因为 Keras 找不到名为 dice_coef_loss 的损失函数。这个损失函数是我在函数训练过程中自定义的损失函数,具体如下
# parameter for loss function
smooth = 1.
# metric function and loss function
def dice_coef(y_true, y_pred):
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
intersection = K.sum(y_true_f * y_pred_f)
return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
def dice_coef_loss(y_true, y_pred):
return -dice_coef(y_true, y_pred)
在这里我自定义了一个指标 dice_coef
和一个损失函数 dice_coef_loss
。因为使用 model.save(filepath)
得到的会保存训练的损失函数,但是这个损失函数在 Keras 中的 losses.py 是找不到的是,所以才会报这样的错。
首先可以看一下函数 load_model
的源码,在这里只给出说明部分如下
def load_model(filepath, custom_objects=None, compile=True):
"""Loads a model saved via `save_model`.
# Arguments
filepath: String, path to the saved model.
custom_objects: Optional dictionary mapping names
(strings) to custom classes or functions to be
considered during deserialization.
compile: Boolean, whether to compile the model
after loading.
# Returns
A Keras model instance. If an optimizer was found
as part of the saved model, the model is already
compiled. Otherwise, the model is uncompiled and
a warning will be displayed. When `compile` is set
to False, the compilation is omitted without any
warning.
# Raises
ImportError: if h5py is not available.
ValueError: In case of an invalid savefile.
"""
其中的 custom_objects
是可选的字典,在反序列化过程中映射名称(字符串)到要考虑的自定义类或函数,所以可以直接通过字典来制定缺失的指标或者损失函数,如下
# parameter for loss function
smooth = 1.
# metric function and loss function
def dice_coef(y_true, y_pred):
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
intersection = K.sum(y_true_f * y_pred_f)
return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
def dice_coef_loss(y_true, y_pred):
return -dice_coef(y_true, y_pred)
# load model
weight_path = './weights.h5'
model = load_model(weight_path,custom_objects={'dice_coef_loss': dice_coef_loss,'dice_coef':dice_coef})
重点看上面代码的最后一行,通过字典指定我们自定义的函数(或许是一个指标,或许是一个损失函数)就可以解决上面的问题。
[1] Bisgates Github https://github.com/keras-team/keras/issues/5916#issuecomment-300038263