Problem description
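When callbacks.ModelCheckpoint() is used together with multi-GPU parallel training, the callback raises an error:
TypeError: can't pickle ... objects (the elided text varies from case to case)
This error closely resembles the one caused by saving a model incorrectly during multi-GPU training: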
To save the multi-gpu model, use .save(fname) or .save_weights(fname) with the template model (the argument you passed to multi_gpu_model), rather than the model returned by multi_gpu_model.
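I also covered this problem in an earlier post: [Keras] Using multiple GPUs with Keras and saving the model. Evidently, the checkpoint callback still calls paralleled_model.save() by default, which triggers the error. To work around this, we need to define our own callback.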
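Which object fails to pickle depends on the model, and this post does not pin it down. As a framework-free illustration of the error class only, the sketch below deep-copies a plain Python object that holds a thread lock, one kind of object Python refuses to pickle. The class `HoldsLock` and the helper `try_deepcopy` are hypothetical and are not part of Keras; whether a lock is what fails inside a given parallel model is an assumption.

```python
import copy
import threading


class HoldsLock:
    """Hypothetical stand-in for an object graph with an unpicklable member."""

    def __init__(self):
        self.weights = [0.1, 0.2]
        self.lock = threading.Lock()  # locks cannot be pickled or deep-copied


def try_deepcopy(obj):
    """Return None on success, or the TypeError message on failure."""
    try:
        copy.deepcopy(obj)
    except TypeError as exc:
        return str(exc)
    return None


error_message = try_deepcopy(HoldsLock())
print(error_message)
```

The exact wording of the message differs between Python versions, which matches the "varies from case to case" note on the error text.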
Solutions
Method 1
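import keras
from keras.utils import multi_gpu_model

original_model = ...  # the template (single-device) model
parallel_model = multi_gpu_model(original_model, gpus=n)

class MyCbk(keras.callbacks.Callback):
    def __init__(self, model):
        super(MyCbk, self).__init__()
        # Keep our own reference to the template model.
        self.model_to_save = model

    def on_epoch_end(self, epoch, logs=None):
        # Save the template model, not the parallel wrapper.
        self.model_to_save.save('model_at_epoch_%d.h5' % epoch)

cbk = MyCbk(original_model)
parallel_model.fit(..., callbacks=[cbk])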
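Keras passes a zero-based `epoch` to `on_epoch_end`, so the first file written by the callback above is `model_at_epoch_0.h5`. If one-based names that also record the validation loss are preferred, a small helper along these lines could be called inside the callback. The name `checkpoint_name` and the filename pattern are assumptions for illustration, not part of Keras.

```python
def checkpoint_name(epoch, logs=None, prefix="model_at_epoch"):
    """Build a filename from a zero-based epoch index and an optional logs dict."""
    logs = logs or {}
    name = "%s_%03d" % (prefix, epoch + 1)  # shift to one-based numbering
    val_loss = logs.get("val_loss")
    if val_loss is not None:
        # Append the validation loss when the logs provide one.
        name += "_val_loss_%.4f" % val_loss
    return name + ".h5"


print(checkpoint_name(0))
print(checkpoint_name(4, {"val_loss": 0.25}))
```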
Method 2
from keras.callbacks import ModelCheckpoint

class ParallelModelCheckpoint(ModelCheckpoint):
    def __init__(self, model, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        # model: the template model, i.e. the one passed to multi_gpu_model
        self.single_model = model
        super(ParallelModelCheckpoint, self).__init__(
            filepath, monitor, verbose, save_best_only,
            save_weights_only, mode, period)

    def set_model(self, model):
        # Ignore the parallel model handed over by fit() and checkpoint the template model.
        super(ParallelModelCheckpoint, self).set_model(self.single_model)

check_point = ParallelModelCheckpoint(single_model, 'best.hd5')
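The subclass above works because the training loop hands the model being trained to each callback through `set_model` before training starts, and overriding that hook lets the checkpoint swap in the template model. The framework-free sketch below mimics that hand-off. All classes here (`TemplateModel`, `ParallelWrapper`, `StockCheckpoint`, `RedirectingCheckpoint`) are hypothetical stand-ins rather than Keras classes, and the wrapper's failing `save` only imitates the error described at the top of this post.

```python
class TemplateModel:
    """Stand-in for the single-device model; records every save call."""

    def __init__(self):
        self.saved_to = []

    def save(self, path):
        self.saved_to.append(path)


class ParallelWrapper:
    """Stand-in for the multi-GPU wrapper, whose save fails."""

    def __init__(self, template):
        self.template = template

    def save(self, path):
        raise TypeError("can't pickle ... objects")


class StockCheckpoint:
    """Saves whatever model the training loop hands over via set_model."""

    def __init__(self, path):
        self.path = path
        self.model = None

    def set_model(self, model):
        self.model = model

    def on_epoch_end(self, epoch):
        self.model.save(self.path)


class RedirectingCheckpoint(StockCheckpoint):
    """Ignores the model passed by the training loop and keeps the template."""

    def __init__(self, template, path):
        super().__init__(path)
        self.template = template

    def set_model(self, model):
        super().set_model(self.template)


def run_one_epoch(training_model, callback):
    """Mimic the training loop: attach the model, then end one epoch."""
    callback.set_model(training_model)
    callback.on_epoch_end(0)


template = TemplateModel()
parallel = ParallelWrapper(template)

try:
    run_one_epoch(parallel, StockCheckpoint("stock.h5"))
    stock_failed = False
except TypeError:
    stock_failed = True

run_one_epoch(parallel, RedirectingCheckpoint(template, "best.h5"))
print(stock_failed, template.saved_to)
```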
Method 3
import numpy as np
import keras

class CustomModelCheckpoint(keras.callbacks.Callback):
    def __init__(self, model, path):
        super(CustomModelCheckpoint, self).__init__()
        # Not stored as self.model: Keras overwrites that attribute with the model being fit.
        self.model_to_save = model
        self.path = path
        self.best_loss = np.inf

    def on_epoch_end(self, epoch, logs=None):
        val_loss = (logs or {}).get('val_loss')
        # Save the weights only when the validation loss improves.
        if val_loss is not None and val_loss < self.best_loss:
            print("\nValidation loss decreased from {} to {}, saving model".format(self.best_loss, val_loss))
            self.model_to_save.save_weights(self.path, overwrite=True)
            self.best_loss = val_loss

model.fit(X_train, y_train,
          batch_size=batch_size*G, epochs=nb_epoch, verbose=0, shuffle=True,
          validation_data=(X_valid, y_valid),
          callbacks=[CustomModelCheckpoint(model, '/path/to/save/model.h5')])
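To see which epochs the callback above would checkpoint, its comparison can be replayed without Keras. The function `epochs_that_save` is a hypothetical helper that applies the same "strictly lower than the best so far" rule to a list of validation losses, and the loss values are made up for illustration.

```python
import math


def epochs_that_save(val_losses):
    """Return the zero-based epochs whose loss beats every earlier loss."""
    best_loss = math.inf
    saved = []
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best_loss:
            saved.append(epoch)
            best_loss = val_loss
    return saved


print(epochs_that_save([0.90, 0.70, 0.75, 0.70, 0.60]))  # prints [0, 1, 4]
```

Because the comparison is strict, an epoch that only ties the best loss does not overwrite the saved weights.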
References
- Multi_gpu in keras not working with callbacks, but works fine if callback is removed #8649
- call_back_error when using multi_gpu_model #10218
- Training a network on multiple GPUs at once with Keras
- ModelCheckpoint callback with multi_gpu fails to save the model, throws error after 1st epoch #8764