Tf上保存变量

1.定义文件的保存路径

ckpt_dir="./ckpt_dir"

    ifnotos.path.exists(ckpt_dir):

    os.makedirs(ckpt_dir)

2.定义一个全局变量

global_step=tf.Variable(0,name='global_step',trainable=False)

这个全局变量是保存文件和提取文件的标识,比如我现在要load什么时候保存的变量

3.定义saver方法

saver=tf.train.Saver()

注意任何变量定义在saver前面的都会被保存,在其后面定义的都不会被保存

4.保存变量

注意看前面定义的变量global_step,第一步给这个变量更新值(epoch),然后再保存。所以这个变量是以后load哪个文件的依据

global_step.assign(i).eval()#set and update(eval) global_step with index, i

saver.save(sess, ckpt_dir+"/model.ckpt",global_step=global_step)

5.load变量

ckpt=tf.train.get_checkpoint_state(ckpt_dir)

if ckpt and ckpt.model_checkpoint_path:

    print(ckpt.model_checkpoint_path)

    saver.restore(sess, ckpt.model_checkpoint_path)#restore all variables

你可能感兴趣的:(Tf上保存变量)