Tensorflow深度学习实战之(五)--保存与恢复模型

文章目录

  • 一、保存模型
  • 二、恢复模型
  • 三、使用模型预测


一、保存模型

在训练完Tensorflow模型为了方便对新的数据进行预测需要保存该模型,Tensorflow提供 tf.train.Saver() 函数建立一个saver对象,然后在会话中调saver的save()函数即可将模型保存起来,形式为:

saver.save(sess, save_path, global_step=None),主要用来保存模型其参数意义为:
sess:保存模型要求必须有一个加载了计算图的会话
save_path:模型保存路径及保存名称
global_step:如果提供的话这个数字会添加到save_path后面,用于区分不同训练阶段的结果

保存后的的模型会在指定文件夹下生成四个文件,每个文件的名称及作用如下:

  1. checkpoint:包含所有权重weights,偏置biases,梯度gradients和所有其他保存的变量variables的二进制文件;
  2. data file:保存了模型的所有变量的值;
  3. meta file:保存了graph结构,当存在meta file,可以不在文件中定义模型,也可以运行;
  4. index file:一个键值对列表,列表的key值为tensor名,列表的value为BundleEntryProto

实现两个矩阵的加法运算并保存模型的程序案例为:

import tensorflow as tf

m1 = tf.Variable(tf.constant([[1.0, 3.0], [2.0, 4.0]], shape=[2, 2]), name="m1")
# 定义张量m2
m2 = tf.Variable(tf.constant([[2.0, 7.0], [3.0, 8.0]], shape=[2, 2]), name="m2")
# 实现两个张量求和
result = m1 + m2
# 建立一个保存器对象
saver = tf.train.Saver()
with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 print("result:", sess.run(result))
 # 实现对象的保存,在当前文件路径创建的model文件夹下名称为model.ckpt的文件
 saver.save(sess, "model/model.ckpt")

其输出的文件结构和结果为:
Tensorflow深度学习实战之(五)--保存与恢复模型_第1张图片

result: [[ 3. 10.]
 [ 5. 12.]]

二、恢复模型

将模型保存好以可以在创建的session会话中调用saver的restor()函数从指定的路径找到模型文件,形式为:

save.restore(sess, save_path),主要用来从会话中恢复模型其参数意义为:
sess:用以恢复参数模型的会话
save_path:以保存模型的路径通常包含模型名字。

重新定义计算图上的节点,使用restore加载模型,在会话中会打印之前模型里面的输出结果

import tensorflow as tf

v1 = tf.Variable(tf.constant([[5.0, 6.0], [7.0, 7.0]], shape=[2, 2]), name="m1")
v2 = tf.Variable(tf.constant([[4.0, 6.0], [7.0, 8.0]], shape=[2, 2]), name="m2")
result = v1 + v2
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
 sess.run(init_op)
 saver.restore(sess, "model/model.ckpt")
 print(sess.run(result))

输出结果为:

[[ 3. 10.]
 [ 5. 12.]]

当然也可以使用 tf.train.import_meta_graph(save_path) 函数直接加载已经持久化的计算图结构,而不加载参数

import tensorflow as tf

saver = tf.train.import_meta_graph("model/model.ckpt.meta") #直接加载持久化的图文件
with tf.compat.v1.Session()as sess:
 saver.restore(sess, 'model/model.ckpt') #将参数加载到模型中
 print("m1", sess.run(tf.get_default_graph().get_tensor_by_name('m1:0'))) #获取m1的节点
 print("m2", sess.run(tf.get_default_graph().get_tensor_by_name('m2:0'))) #获取m2的节点
 print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0'))) #获取add节点得到的结果

输出结果为:

m1 [[1. 3.]
 [2. 4.]]
m2 [[2. 7.]
 [3. 8.]]
[[ 3. 10.]
 [ 5. 12.]]

三、使用模型预测

使用模型预测步骤:
1、占位符构建计算图并保存模型

mport tensorflow as tf

# 分别定义2行3列的占位符变量
x = tf.placeholder(shape=(2, 3), dtype=tf.float32, name='x')
y = tf.placeholder(shape=(2, 3), dtype=tf.float32, name='y')
# 定义一个2行3列的变量
b = tf.Variable(tf.ones([2, 3]), dtype=tf.float32, name='b')
# 实现两个矩阵的点成
mul_result = tf.multiply(x, y, name='mul_result')
# 实现两个张量的相加
add_result = tf.add(mul_result, b, name='add_result')
saver = tf.train.Saver()
with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 # 保存模型
 saver.save(sess, "predict/predict_model.ckpt")

2、加载模型喂入数据进行预测

import tensorflow as tf

# 加载计算图,不加载参数
saver = tf.train.import_meta_graph('predict/predict_model.ckpt.meta')
with tf.Session()as sess:
# 加载x节点
 input_x = sess.graph.get_tensor_by_name('x:0')
# 加载y节点
 input_y = sess.graph.get_tensor_by_name('y:0')
# 获得矩阵相乘操作
 mul_result = sess.graph.get_tensor_by_name('mul_result:0')
# 向节点喂入数据,获得输出结果
 result = sess.run(mul_result, feed_dict={input_x: [[2, 3, 4], [2, 3, 4]], input_y: [[1, 2, 3], [3, 5, 5]]})
 print("矩阵乘法结果:", result)

你可能感兴趣的:(深度学习,tensorflow,机器学习)