TensorFlow和Keras的模型保存及载入模型参数继续训练

TensorFlow

在TensorFlow中,模型的持久化保存和加载主要通过Saver()。
在初次训练之后调用如下的save函数保存,然后,在预测前,或者在继续训练前调用load加载参数即可。

def __init__():
	self.sess = tf.Session()
	# 定义好网络结构...
	self.sess.run(tf.global_variables_initializer())
def check_path(self, path):
    if not os.path.exists(path):
        os.mkdir(path)
def save(self):
    self.check_path('model')
    saver=tf.train.Saver(tf.global_variables(),max_to_keep=10)
    print("model: ",saver.save(self.sess,'model/modle.ckpt'))

def load(self):
    saver=tf.train.Saver(tf.global_variables())
    module_file = tf.train.latest_checkpoint('model')
    saver.restore(self.sess, module_file)

Keras

Keras 和 TensorFlow一样,模型预测和加载参数继续训练是一样的:

首先在模型训练好之后进行模型的保存,当然,直接使用 model.save(‘model.h5’) 也是一样的。

在模型继续训练或者进行测试之前加载原来训练好的参数即可:

def save(self):
    self.actor.save_weights('model/ddpg_actor.h5')
    self.critic.save_weights('model/ddpg_critic.h5')
    
def load(self):
     if os.path.exists('model/ddpg_actor.h5') and os.path.exists('model/ddpg_critic.h5'):
         self.actor.load_weights('model/ddpg_actor.h5')
         self.critic.load_weights('model/ddpg_critic.h5')

if __name__ == '__main__':
    model = DDPG()
    if for_train:
        model.load()
        model.train()
    if for_test:
        model.load()
        model.play()

模型参数的加载,相比TensorFlow,keras还是更加便捷的,惊叹于keras的简单易用

你可能感兴趣的:(深度学习)