(tensorflow)模型的保存和载入

我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow提供了两种保存模型的方式,一种是使用tf.train.Saver函数来保存TensorFlow程序的参数和完整的模型结构,保存的文件后缀为 “.ckpt”;另一种方式是将计算图保存在一个 “.pb” 文件中。

使用tf.train.saver()进行模型的保存

保存

(tensorflow)模型的保存和载入_第1张图片
在创建Saver对象时,有一个参数我们经常会用到,就是max_to_keep参数,这个是用来设置保存模型的个数,默认为5,即max_to_keep=5,保存最近的5个模型。如果想每训练一个epoch就保存一次模型,则可以将max_to_keep设置为None或0,如:

saver = tf.train.Saver(max_to_keep=None)

在创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess, 'ckpt/mnist.ckpt', global_step=global_step)

第三个参数将训练的次数作为后缀加入到模型中:

saver.save(sess, 'my_model', global_step=0)          # filename: my-model-0
saver.save(sess, 'my_model', global_step=1000)    # filename: my-model-1000

其中需要注意的是:

  1. 变量在定义的时候必须要指定 ‘name’ 参数的值,因为在载入模型时需要根据这个值来提取对应的 Tensor 或 Operator。
  2. 在saver.save的时候,每个checkpoint会保存三个文件,如:
    my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
    其中 ‘.meta’ 后缀的文件保存的是模型的计算图,其余保存的是模型的参数和权值。

载入

(tensorflow)模型的保存和载入_第2张图片
tf.train.import_meta_graph() 方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

将tensorflow的模型网络导出为单个文件

有时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。
我们知道,graph_def文件中没有包含网络中的Variable值(通常情况权重定义为Variable),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

保存

在保存模型之前,需要先将网络的权重进行冻结,使用:
tensorflow.python.framework.graph_util.convert_variables_to_constant()

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
 
# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')
 
# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
    constant_graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
    with tf.gfile.FastGFile(pd_file_path, mode = 'wb') as f:
          f.write(constant_graph.SerializeToString())

载入

with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        pb_file_path="catdog.pb"
        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read()) #rb
            _ = tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess:
            tf.global_variables_initializer().run()
 
            input_x = sess.graph.get_tensor_by_name("input:0") ####这就是刚才取名的原因
            print (input_x)
            out_softmax = sess.graph.get_tensor_by_name("softmax:0")
            print (out_softmax)
            out_label = sess.graph.get_tensor_by_name("output:0")
            print (out_label)
 
            img = np.array(Image.open(jpg_path).convert('L'))
            img = transform.resize(img, (H, W, 3))
 
            plt.figure("fig1")
            plt.imshow(img)
            img = img * (1.0 /255)
            img_out_softmax = sess.run(out_softmax, feed_dict={input_x:np.reshape(img, [-1, H, W, 3])})

(tensorflow)模型的保存和载入_第3张图片

你可能感兴趣的:(计算机视觉,TensorFlow,tensorflow)