TensorFlow 模型持久化

保存变量

这里保存两个变量（变量 v 及其滑动平均影子变量），然后在另一个文件中读取它们。

# Save a variable together with its exponential-moving-average (EMA) shadow
# variable into one checkpoint, so a separate script can restore them.
import tensorflow as tf
from prepare import Prepare
Prepare()  # project setup helper; its effect is not visible here — TODO confirm

# Checkpoint location; the loading example reads the same path.
CKPT_DIR = "Saved_model"
CKPT_PATH = CKPT_DIR + "/model2.ckpt"

v = tf.Variable(0, dtype=tf.float32, name="v")
for var in tf.global_variables():
    print(var.name)  # v:0
print("---")

# ema.apply() creates a shadow variable for v and adds it to the global
# variables, so this second listing shows two names.
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for var in tf.global_variables():
    print(var.name)
    # v:0 v/ExponentialMovingAverage:0

# Build the Saver only after ema.apply(): its default var_list is taken at
# construction time, so this way it covers the shadow variable as well as v.
saver = tf.train.Saver()

# Saver.save() raises if the checkpoint's parent directory is missing;
# MakeDirs creates it and is a no-op when it already exists.
tf.gfile.MakeDirs(CKPT_DIR)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    sess.run(tf.assign(v, 10))
    # One EMA update with decay 0.99: shadow = 0.99 * 0 + 0.01 * 10 ≈ 0.1
    sess.run(maintain_averages_op)
    # Saving writes both v:0 and v/ExponentialMovingAverage:0.
    saver.save(sess, CKPT_PATH)
    print(sess.run([v, ema.average(v)]))  # [10.0, 0.099999905]

加载变量

# Restore the checkpoint written by the saving example, remapping checkpoint
# names so the saved moving average of v lands in `v` and the saved raw value
# of v lands in `m`.
import tensorflow as tf
from prepare import Prepare
Prepare()  # project setup helper; its effect is not visible here — TODO confirm

v = tf.Variable(0, dtype=tf.float32, name="v")
m = tf.Variable(1, dtype=tf.float32, name="g")

# Keys: variable names stored in the checkpoint.
# Values: graph variables that receive those saved values.
# Mapping the EMA entry onto `v` loads the averaged value directly into v.
checkpoint_to_graph = {
    "v/ExponentialMovingAverage": v,
    "v": m,
}
saver = tf.train.Saver(checkpoint_to_graph)

with tf.Session() as sess:
    # restore() assigns every variable in the mapping, so no initializer
    # is run here.
    saver.restore(sess, "Saved_model/model2.ckpt")
    # Prints v (≈ 0.1, the restored moving average) then m (10.0).
    for restored in (v, m):
        print(sess.run(restored))

你可能感兴趣的:(tensorflow)