【keras】多GPU训练ModelCheckpoint()保存模型

使用多GPU训练时,需要保存的是在CPU上创建的单模型(single model),而不是multi_gpu_model返回的并行模型。如果直接使用ModelCheckpoint会报错,因此需要继承ModelCheckpoint类并重写其set_model方法,使回调保存的是单模型。参考实现如下:

from keras.callbacks import ModelCheckpoint
from keras.utils import multi_gpu_model
class ParallelModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint that always checkpoints a caller-supplied model.

    Intended for multi-GPU training: the model passed to ``fit`` is the
    parallel wrapper, but the one that should be written to disk is the
    original single model it was built from. This callback remembers that
    single model and substitutes it for whatever model the training loop
    hands to the callback.

    Args:
        model: The single (template) model whose state should be saved.
        filepath: Destination path, forwarded to ``ModelCheckpoint``.
        monitor: Forwarded to ``ModelCheckpoint``.
        verbose: Forwarded to ``ModelCheckpoint``.
        save_best_only: Forwarded to ``ModelCheckpoint``.
        save_weights_only: Forwarded to ``ModelCheckpoint``.
        mode: Forwarded to ``ModelCheckpoint``.
        period: Forwarded to ``ModelCheckpoint``.
    """

    def __init__(self, model, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        self.single_model = model
        # Forward options by keyword rather than by position. The parent's
        # positional order is not stable across Keras releases (in some, the
        # slot after ``mode`` is ``save_freq`` instead of ``period``), so a
        # positional ``period`` could silently be taken for another option.
        super(ParallelModelCheckpoint, self).__init__(
            filepath,
            monitor=monitor,
            verbose=verbose,
            save_best_only=save_best_only,
            save_weights_only=save_weights_only,
            mode=mode,
            period=period)

    def set_model(self, model):
        """Register the stored single model, ignoring ``model``.

        The ``model`` argument is intentionally discarded so that the
        checkpoint logic inherited from ``ModelCheckpoint`` operates on
        ``self.single_model`` instead of the model being trained.
        """
        super(ParallelModelCheckpoint, self).set_model(self.single_model)

# Build the single (template) model; this is the one that gets checkpointed.
single_model = Model(...)

# Wrap it for data-parallel training across 4 GPUs.
paralle_model = multi_gpu_model(single_model, gpus=4)

# The callback is given the single model, so checkpoints are written from it
# even though fit() is called on the parallel wrapper.
checkpoint_callback = ParallelModelCheckpoint(single_model, 'best.hd5')
paralle_model.fit(..., callbacks=[checkpoint_callback])

你可能感兴趣的:(Tensorflow)