tensorflow的saved_model存取模型

一种工程级方便的存取模型的方法,saved_model
通过存取一个简单的模型来作为示范
首先是模型定义

import tensorflow as tf
import numpy as np


W = tf.get_variable(name="demo", initializer=tf.ones([10, 32],dtype=tf.float32))
x = tf.placeholder(dtype=tf.float32, shape=[None, 10])

y = tf.matmul(x, W)
y_ = np.ones(shape=[10, 32], dtype=np.float32) # 使用np来创造两个label

cost = tf.nn.sigmoid_cross_entropy_with_logits(logits=y, labels=y_, name=None)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(cost)

这里定义了一个简单的矩阵乘, 然后我们来简单的训练几步

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    feed_dict = {x: np.ones([10, 10])}
    for i in range(100):
        sess.run(train_op, feed_dict=feed_dict)
    print(sess.run(y, feed_dict=feed_dict))

现在我们想把这个模型存储起来,传统的做法是用ckpt来做,现在tensorflow提供一种更强大简便的方法

首先构建两个字典,inputs 和 outputs, 把要存入的变量放入字典中
其中 tf.saved_model.utils.build_tensor_info是把变量变成可缓存对象的函数

    saved_model_dir = "save_model"
    signature_key = 'test_signature'
    input_key = 'input_x'
    output_key = 'output'

    # x 为输入tensor
    inputs = {input_key: tf.saved_model.utils.build_tensor_info(x)}
    # y 为最终需要的输出结果tensor
    outputs = {output_key: tf.saved_model.utils.build_tensor_info(y)}

然后把两个字典打包放入 signature 中

signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=inputs,
        outputs=outputs,
        method_name=signature_key)

然后建立SavedModelBuilder,并以signature的形式添加要存储的变量

builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
builder.add_meta_graph_and_variables(
        sess=sess,
        tags=['test_saved_model'],
        signature_def_map={signature_key: signature},
        clear_devices=True)
    builder.save()

saved_model_dir 是要存模型的文件夹,可以是一个不存在的目录名,save之后,包括图结构,变量的内容,都会被存入到新创建的 saved_model_dir 目录内,下图就是存好的模型


下面我们来取出一个训练好的模型
用 tf.saved_model.loader.load 从 模型文件夹中取出模型
其中tags字段是['test_saved_model'], 与存模型时候指定的字段相同
把模型导入到session之后, 取出signature 就从signature中取出存入的变量了

saved_model_dir = "save_model"
signature_key = 'test_signature'
input_key = 'input_x'
output_key = 'output'

with tf.Session() as sess1:

    meta_graph_def = tf.saved_model.loader.load(sess1, ['test_saved_model'], saved_model_dir)
    signature = meta_graph_def.signature_def
    x_tensor_name = signature[signature_key].inputs[input_key].name
    y_tensor_name = signature[signature_key].outputs[output_key].name
    print(x_tensor_name)
    print(y_tensor_name)
    x = sess1.graph.get_tensor_by_name(x_tensor_name)
    y = sess1.graph.get_tensor_by_name(y_tensor_name)
    feed_dict = {x: np.ones([1, 10])}
    print(sess1.run(y, feed_dict=feed_dict))

我们看到,首先我们从signature 和 inputs/outputs都是一种字典的封装,把tensor_name存入到了字典中
传统的导入 需要用get_tensor_by_name , 这样就需要记录tensor的name熟悉,很麻烦。
通过signature,我们可以指定变量的别名,方便存取。

另外,存模型和变量的时候,会把全部的模型图存入,并不是只存我们指定几个变量,而signature只是方便我们存取想要使用的变量。

一个坑,使用tf.Session的时候,切记默认图和指定图的区别。tf.Session()会导入默认图的结构, 而导入模型是需要依附于sess的图, 在默认图中导入模型,如果默认图定义了其他计算图,会导致图冲突,模型导不进去。!!!

你可能感兴趣的:(tensorflow的saved_model存取模型)