Tensorflow 数据保存(2)

1. saver=tf.train.Saver (tf.global_variables(),max_to_keep)
(1)max_to_keep
这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:

saver=tf.train.Saver(max_to_keep=0)

当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

saver=tf.train.Saver(max_to_keep=1)

(2)tf.global_variables()
只保存tf.global_variables()里的这些变量,如果saver=tf.train.Saver()里面不传入参数,默认保存全部变量

weight=[weights['wc1'],weights['wc2'],weights['wc3a']]
saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值
saver.save(sess,'model.ckpt')


2. saver.save ()
创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

第一个参数sess,第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。

3. saver.restore ()
模型的恢复用的是restore()函数,它需要两个参数restore(sess, save_path),save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint ()来自动获取最后一次保存的模型。如:

model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)

你可能感兴趣的:(Tensorflow 数据保存(2))