1. 注意:如果保存模型和加载模型在两个.py文件中,在Spyder的Ipython console中保存完模型后,v1是'Variable:0',v2是'Variable_1:0',紧接着运行加载模型,此文件中的v1,v2就是'Variable_2:0'和'Variable_3:0',加载会出错,需关掉Ipython console后,重新运行使得modelLoad.py中的v1和v2分别是'Variable:0'和'Variable_1:0'才能运行成功。
保存模型
import tensorflow as tf
#声明两个变量并计算它们的和
v1 = tf.Variable(tf.constant(1.0,shape=[1],name='v1'))
v2 = tf.Variable(tf.constant(2.0,shape=[1],name='v2'))
result = v1 + v2
init_op = tf.global_variables_initializer()
#声明tf.train.Saver类用于保存模型
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
#将模型保存到model/model.ckpt文件
saver.save(sess,'model/model1/model.ckpt')
加载模型
import tensorflow as tf
#使用和保存模型代码中一样的方式来声明变量
v1 = tf.Variable(tf.constant(1.0,shape=[1],name='v1'))
v2 = tf.Variable(tf.constant(2.0,shape=[1],name='v2'))
result = v1 + v2
#声明tf.train.Saver类用于保存模型
saver = tf.train.Saver()
with tf.Session() as sess1:
#加载已经保存的模型,并通过已经保存的模型中变量的值来计算加法
saver.restore(sess1,'model/model1/model.ckpt')
print(sess1.run(result))
runfile('F:/python学习201712/TensorFlow/modelLoad.py', wdir='F:/python学习201712/TensorFlow')
INFO:tensorflow:Restoring parameters from model/model1/model.ckpt
[ 3.]
2. 如果不希望重复定义图上的运算,也可以直接加载已持久化的图
import tensorflow as tf
#如果不希望重新定义图上的运算,也可以直接加载已经持久化的图
saver = tf.train.import_meta_graph('model/model1/model.ckpt.meta')
with tf.Session() as sess:
saver.restore(sess,'model/model1/model.ckpt')
#通过张量的名称来获取张量
print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))
runfile('F:/python学习201712/TensorFlow/modelLoad2.py', wdir='F:/python学习201712/TensorFlow')
INFO:tensorflow:Restoring parameters from model/model1/model.ckpt
[ 3.]
3. tf.train.Saver支持在保存或者加载时给变量重命名
import tensorflow as tf
"""为了保存或者加载部分变量,在声明tf.train.Saver类时可以提供一个列表来指定需要保存
或者加载的变量。"""
#这里声明的变量名称和已经保存的模型中变量的名称不同
v1 = tf.Variable(tf.constant(1.0,shape=[1],name='other_v1'))
v2 = tf.Variable(tf.constant(2.0,shape=[1],name='other_v2'))
#如果直接使用tf.train.Saver()来加载模型会报变量找不到的错误。
#使用一个字典来重命名变量就可以加载原来的模型了。这个字典指定了原来名称为v1的变量现在
#加载到变量v1中(名称为other-v1),名称为v2的变量加载到变量v2中(名称为other-v2)
saver = tf.train.Saver({'v1':v1,'v2':v2})
4. 保存和加载滑动平均模型
import tensorflow as tf
"""保存滑动平均模型"""
v = tf.Variable(0, dtype=tf.float32, name='v')
#在没有申明滑动平均模型时只有一个变量v,所以下面的语句只会输出'v:0'
for variables in tf.global_variables():
print(variables.name)
print('='*60)
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.global_variables())
#在申明滑动平均模型之后,Tensorflow会自动生成一个影子变量
#v/ExponentialMoving Average,于是下面的语句输出
# 'v:0'和'v/ExponentialMovingAverage:0'
for variables in tf.global_variables():
print(variables.name)
saver = tf.train.Saver()
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
sess.run(tf.assign(v,10))
sess.run(maintain_average_op)
#保存时,Tensorflow会将v:0和v/ExponentialMovingAverage:0两个变量都存下来
saver.save(sess,'model/model2/model.ckpt')
print(sess.run([v,ema.average(v)]))
#通过变量重命名直接读取变量的滑动平均值
import tensorflow as tf
v = tf.Variable(0, dtype=tf.float32, name='v')
#通过变量重命名将原来变量v的滑动平均值直接赋值给v
saver = tf.train.Saver({'v/ExponentialMovingAverage':v})
with tf.Session() as sess:
saver.restore(sess,'model/model2/model.ckpt')
print(sess.run(v))
#为了加载时重命名滑动平均变量,tf.train.ExponentialMovingAverage类提供了
#variables_to_restore函数来生成tf.train.Saver类所需要的变量重命名字典。
ema = tf.train.ExponentialMovingAverage(0.99)
#通过使用variables_to_restore函数可以直接生成上面代码中提供的字典
# {'v/ExponentialMovingAverage':v}
#以下代码会输出
#{'v/ExponentialMovingAverage':}
#其中后面的Variable类就代表了变量v
print(ema.variables_to_restore())
saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
saver.restore(sess,'model/model2/model.ckpt')
print(sess.run(v))
runfile('F:/python学习201712/TensorFlow/modelLoad3.py', wdir='F:/python学习201712/TensorFlow')
INFO:tensorflow:Restoring parameters from model/model2/model.ckpt
0.0999999
{'v/ExponentialMovingAverage': }
INFO:tensorflow:Restoring parameters from model/model2/model.ckpt
0.0999999