如何使用Tensorflow保存或者加载模型(二) -- ModelBuilder API

1.背景

在上一篇如何使用Tensorflow保存或者加载模型(一)文章中,站长介绍了怎么把Tensorflow模型的图和变量通过tf.train.Saver()保存在本地。在这一篇文章中,站长会介绍用一种新的模型保存和加载的方式,ModelBuilder API,在该方式下保存和加载模型会更加简单,而且支持Python和Java环境下运行,可以更好地满足工业界的需求。

1.1 模型文件介绍

ModelBuilder API会生成saved_model.pb的文件和variables的文件夹。

saved_model.pb 中的后缀pb代表protobuf,在Tensorflow中这个pb文件包含了模型图的定义和模型的权重,也是模型保存的核心文件。

variables 文件夹中包含的是变量的数据和索引文件。

Tensorflow的pb文件及变量目录

1.2 模型的保存示例代码

我们这里仍然使用linear regression模型进行演示,使用tf.saved_model.simple_save进行模型保存。

# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np

##1.创建PlaceHolder和初始化参数##
X = tf.placeholder("float", name="X")
Y = tf.placeholder("float", name="Y")

W = tf.Variable(np.random.randn(), name= "W")
b = tf.Variable(np.random.randn(), name= "b")

learning_rate = 0.02
epochs = 100

data_x = np.linspace(0, 50, 50)
data_y = np.linspace(0, 50, 50)

##2.实现梯度下降##
y_pred = tf.add(tf.multiply(X, W), b, name="y_pred")
loss = tf.reduce_sum(tf.pow(y_pred-Y, 2)) / (2 * len(data_x))
opt = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
##初始化变量##
init = tf.global_variables_initializer()

##3.构建Tensorflow Session##
with tf.Session() as sess:
    sess.run(init)

    for epoch in range(epochs):
        for (batch_x, batch_y) in zip(data_x, data_y):
            sess.run(opt, feed_dict={X: batch_x, Y: batch_y})

        if (epoch + 1) % 10 == 0:
            cost = sess.run(loss, feed_dict={X:data_x, Y:data_y})
            print("Epoch", (epoch + 1), ": cost =", cost, "W =", sess.run(W), "b =", sess.run(b))


    # 存储必须的变量#
    training_cost = sess.run(loss, feed_dict={X:data_x, Y:data_y})
    weight = sess.run(W)
    bias = sess.run(b)

    # 用变量进行预测#
    predictions = weight * X + bias
    print("Training Cost =", training_cost, "Weight =", weight, "bias =", bias, '\n')
    print("预测结果:", weight * 0.01 + bias)

    # 保存模型#
    input_dict = {
        "X": X,
        "Y": Y
    }

    output_dict = {
        "y_pred": y_pred
    }

    tf.saved_model.simple_save(sess, "./result/model_save", input_dict, output_dict)
    print("保存模型")

运行结果如下:

Epoch 10 : cost = 5.5811808e-05 W = 0.9996358 b = 0.018207774
Epoch 20 : cost = 5.1781208e-05 W = 0.9996493 b = 0.017536966
Epoch 30 : cost = 4.8034057e-05 W = 0.99966216 b = 0.016890889
Epoch 40 : cost = 4.455984e-05 W = 0.9996746 b = 0.016268639
Epoch 50 : cost = 4.1336203e-05 W = 0.9996866 b = 0.015669275
Epoch 60 : cost = 3.8349073e-05 W = 0.99969816 b = 0.015091987
Epoch 70 : cost = 3.5585737e-05 W = 0.99970937 b = 0.014535968
Epoch 80 : cost = 3.2999385e-05 W = 0.99972 b = 0.014000463
Epoch 90 : cost = 3.062137e-05 W = 0.99973035 b = 0.013484693
Epoch 100 : cost = 2.8402981e-05 W = 0.99974024 b = 0.0129879005
Training Cost = 2.8402981e-05 Weight = 0.99974024 bias = 0.0129879005 

预测结果: 0.022985302954912187
保存模型

1.3 模型的加载示例代码(Python)

在Python环境下,加载PB文件也是非常方便,使用tf.saved_model.loader.load一行代码就完成加载了,使用方法一或者方法二都可以。

# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    tf.saved_model.loader.load(sess, [tag_constants.SERVING], "./result/model_save")

    # 方法1:
    # X = graph.get_tensor_by_name("X:0")
    # y_pred = graph.get_tensor_by_name("y_pred:0")
    # print(sess.run(y_pred, feed_dict={ X:0.01}))

    # 方法2:
    print(sess.run("y_pred:0", feed_dict={"X:0": 0.01}))

从打印的结果看,预测结果与训练结束时的结果一致,符合预期

##结果一致
0.022985302

1.4 模型的加载示例代码(Java)

在工业界中,很多公司是以Java为主要开发语言,因为考虑到应用部署的语言兼容性,我们这里也提供了Java环境下的模型加载方式。

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.IOException;
import java.nio.FloatBuffer;

public class model_restore_sample {

    public static void main(String[] args) throws IOException {

        int num_predict = 1;

        String path = "../tensorflow_model_save_restore/result/model_save";

        try(SavedModelBundle savedModelBundle = SavedModelBundle.load(path, "serve");){

            //创建Session
            Session sess = savedModelBundle.session();

            //创建Input Tensor,输入新样本的值是0.01,shape为1
            Tensor X = Tensor.create(new long[] {num_predict},
                                     FloatBuffer.wrap(new float[] {0.01f})
                    );

            //输入Input,获得预测结果
            float[] result = sess.runner().feed("X", X).fetch("y_pred").run().get(0).copyTo(new float[num_predict]);


            //打印结果
            System.out.println(result[0]);
        }

    }
}

结果与Python版本的预测结果一致

0.022985302

2.总结

本文介绍了如何使用ModelBuilder API去保存Tensorflow模型,并提供了Python版本和Java版本的模型加载方式,这些方式更简单而且兼容性更强。也许有的小伙伴会问,既然模型文件已经是固定的话,能否将其做成一个服务,直接用于线上预测呢?答案当然是可以的,后续的文章会给大家介绍一种工业界非常流行的模型部署方式Tensorflow Serving,可以支持大规模线上的调用。

你可能感兴趣的:(如何使用Tensorflow保存或者加载模型(二) -- ModelBuilder API)