1.背景
在上一篇如何使用Tensorflow保存或者加载模型(一)文章中,站长介绍了怎么把Tensorflow模型的图和变量通过tf.train.Saver()
保存在本地。在这一篇文章中,站长会介绍用一种新的模型保存和加载的方式,ModelBuilder API
,在该方式下保存和加载模型会更加简单,而且支持Python和Java环境下运行,可以更好地满足工业界的需求。
1.1 模型文件介绍
ModelBuilder API会生成saved_model.pb
的文件和variables
的文件夹。
saved_model.pb 中的后缀pb代表protobuf,在Tensorflow中这个pb文件包含了模型图的定义和模型的权重,也是模型保存的核心文件。
variables 文件夹中包含的是变量的数据和索引文件。
1.2 模型的保存示例代码
我们这里仍然使用linear regression
模型进行演示,使用tf.saved_model.simple_save
进行模型保存。
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
##1.创建PlaceHolder和初始化参数##
X = tf.placeholder("float", name="X")
Y = tf.placeholder("float", name="Y")
W = tf.Variable(np.random.randn(), name= "W")
b = tf.Variable(np.random.randn(), name= "b")
learning_rate = 0.02
epochs = 100
data_x = np.linspace(0, 50, 50)
data_y = np.linspace(0, 50, 50)
##2.实现梯度下降##
y_pred = tf.add(tf.multiply(X, W), b, name="y_pred")
loss = tf.reduce_sum(tf.pow(y_pred-Y, 2)) / (2 * len(data_x))
opt = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
##初始化变量##
init = tf.global_variables_initializer()
##3.构建Tensorflow Session##
with tf.Session() as sess:
sess.run(init)
for epoch in range(epochs):
for (batch_x, batch_y) in zip(data_x, data_y):
sess.run(opt, feed_dict={X: batch_x, Y: batch_y})
if (epoch + 1) % 10 == 0:
cost = sess.run(loss, feed_dict={X:data_x, Y:data_y})
print("Epoch", (epoch + 1), ": cost =", cost, "W =", sess.run(W), "b =", sess.run(b))
# 存储必须的变量#
training_cost = sess.run(loss, feed_dict={X:data_x, Y:data_y})
weight = sess.run(W)
bias = sess.run(b)
# 用变量进行预测#
predictions = weight * X + bias
print("Training Cost =", training_cost, "Weight =", weight, "bias =", bias, '\n')
print("预测结果:", weight * 0.01 + bias)
# 保存模型#
input_dict = {
"X": X,
"Y": Y
}
output_dict = {
"y_pred": y_pred
}
tf.saved_model.simple_save(sess, "./result/model_save", input_dict, output_dict)
print("保存模型")
运行结果如下:
Epoch 10 : cost = 5.5811808e-05 W = 0.9996358 b = 0.018207774
Epoch 20 : cost = 5.1781208e-05 W = 0.9996493 b = 0.017536966
Epoch 30 : cost = 4.8034057e-05 W = 0.99966216 b = 0.016890889
Epoch 40 : cost = 4.455984e-05 W = 0.9996746 b = 0.016268639
Epoch 50 : cost = 4.1336203e-05 W = 0.9996866 b = 0.015669275
Epoch 60 : cost = 3.8349073e-05 W = 0.99969816 b = 0.015091987
Epoch 70 : cost = 3.5585737e-05 W = 0.99970937 b = 0.014535968
Epoch 80 : cost = 3.2999385e-05 W = 0.99972 b = 0.014000463
Epoch 90 : cost = 3.062137e-05 W = 0.99973035 b = 0.013484693
Epoch 100 : cost = 2.8402981e-05 W = 0.99974024 b = 0.0129879005
Training Cost = 2.8402981e-05 Weight = 0.99974024 bias = 0.0129879005
预测结果: 0.022985302954912187
保存模型
1.3 模型的加载示例代码(Python)
在Python环境下,加载PB文件也是非常方便,使用tf.saved_model.loader.load
一行代码就完成加载了,使用方法一或者方法二都可以。
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
graph = tf.Graph()
with tf.Session(graph=graph) as sess:
tf.saved_model.loader.load(sess, [tag_constants.SERVING], "./result/model_save")
# 方法1:
# X = graph.get_tensor_by_name("X:0")
# y_pred = graph.get_tensor_by_name("y_pred:0")
# print(sess.run(y_pred, feed_dict={ X:0.01}))
# 方法2:
print(sess.run("y_pred:0", feed_dict={"X:0": 0.01}))
从打印的结果看,预测结果与训练结束时的结果一致,符合预期
##结果一致
0.022985302
1.4 模型的加载示例代码(Java)
在工业界中,很多公司是以Java为主要开发语言,因为考虑到应用部署的语言兼容性,我们这里也提供了Java环境下的模型加载方式。
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import java.io.IOException;
import java.nio.FloatBuffer;
public class model_restore_sample {
public static void main(String[] args) throws IOException {
int num_predict = 1;
String path = "../tensorflow_model_save_restore/result/model_save";
try(SavedModelBundle savedModelBundle = SavedModelBundle.load(path, "serve");){
//创建Session
Session sess = savedModelBundle.session();
//创建Input Tensor,输入新样本的值是0.01,shape为1
Tensor X = Tensor.create(new long[] {num_predict},
FloatBuffer.wrap(new float[] {0.01f})
);
//输入Input,获得预测结果
float[] result = sess.runner().feed("X", X).fetch("y_pred").run().get(0).copyTo(new float[num_predict]);
//打印结果
System.out.println(result[0]);
}
}
}
结果与Python版本的预测结果一致
0.022985302
2.总结
本文介绍了如何使用ModelBuilder API去保存Tensorflow模型,并提供了Python版本和Java版本的模型加载方式,这些方式更简单而且兼容性更强。也许有的小伙伴会问,既然模型文件已经是固定的话,能否将其做成一个服务,直接用于线上预测呢?答案当然是可以的,后续的文章会给大家介绍一种工业界非常流行的模型部署方式Tensorflow Serving
,可以支持大规模线上的调用。