TensorFlow目前保存的模型文件主要有两种,ckpt与pb,二者之间的异同请见
https://zhuanlan.zhihu.com/p/32887066
CKPT,首先这种模型文件是依赖 TensorFlow 的,只能在其框架下使用;其次,在恢复模型之前还需要再定义一遍网络结构,然后才能把变量的值恢复到网络中;PB 文件,它具有语言独立性,可独立运行,封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。
下面,我以mnist手写数据集用softmax回归为例,说明如何对训练好的模型进行保存与恢复。
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
sess = tf.InteractiveSession()
x = tf.placeholder("float", shape=[None, 784], name='input_x') # 输入图像占位符
y_ = tf.placeholder("float", shape=[None, 10]) # 标签类别占位符
# 模型参数一般用Variable来表示
W = tf.Variable(tf.zeros([784, 10]), name='w') # 权重W是一个784x10的矩阵(因为我们有784个特征和10个输出值)
b = tf.Variable(tf.zeros([10]), name='b') # 偏置b是一个10维的向量(因为我们有10个分类)
sess.run(tf.initialize_all_variables()) # 变量需要通过seesion初始化后,才能在session中使用
# 使用Tensorflow提供的回归模型softmax,y代表输出,把向量化后的图片x和权重矩阵W相乘,加上偏置b,然后计算每个分类的softmax概率值
y = tf.nn.softmax(tf.matmul(x, W) + b, name='predict')
cross_entropy = - tf.reduce_sum(y_ * tf.log(y)) # 计算交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) # 梯度下降算法以0.01的学习速率最小化交叉熵
# tf.argmax返回某个tensor对象在某一维上的其数据最大值所在的索引值
# 下面这行返回一组布尔值如[True, False, True, True]
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
# 把布尔值转换成浮点数,然后取平均值,[True, False, True, True] 会变成 [1,0,1,1] ,取平均值后得到 0.75
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
for i in range(1000):
batch = mnist.train.next_batch(50) # 每一步迭代加载50个训练样本,然后执行一次train_step
sess.run(train_step, feed_dict={x: batch[0], y_: batch[1]})
if i % 100 == 0:
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})) # 模型在测试数据集上面的正确率
# 随便从测试集中取一个例子做测试
print(sess.run(y, feed_dict={x: np.expand_dims(mnist.test.images[15], axis=0)}))
print(sess.run(tf.argmax(sess.run(y, feed_dict={x: np.expand_dims(mnist.test.images[15], axis=0)}), axis=1))) # 预测结果
print(mnist.test.labels[15]) # 标签值
a.保存为ckpt格式的模型文件
saver = tf.train.Saver()
saver.save(sess, "save_path/file_name")
生成的模型文件如下:
checkpoint文件保存了一个录下多有的模型文件列表,model.ckpt.meta保存了TensorFlow计算图的结构信息,model.ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同,但加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的。
b.保存为pb格式的模型文件
builder = tf.saved_model.builder.SavedModelBuilder('./model2')
builder.add_meta_graph_and_variables(sess, ["mytag"])
builder.save()
生成的模型文件如下:
运行结果:(第6个元素最大,表示数字5,说明预测正确)
0.2847
0.8778
0.8945
0.8972
0.9031
0.9015
0.9109
0.9007
0.8901
0.9061
[[ 2.59213091e-04 1.70691292e-05 1.03438069e-04 1.55748194e-02
2.95701193e-05 9.70679998e-01 7.14014686e-06 7.19119780e-05
1.32500082e-02 6.82865766e-06]]
[5]
[ 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
import os
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# pb模型的恢复
def restore_model_pb():
sess = tf.Session()
tf.saved_model.loader.load(sess, ['mytag'], os.getcwd() + '\model2')
input_x = sess.graph.get_tensor_by_name('input_x:0')
op = sess.graph.get_tensor_by_name('predict:0')
print(sess.run(op, feed_dict={input_x: np.expand_dims(mnist.test.images[15], axis=0)}))
sess.close()
# ckpt模型的恢复
def restore_model_ckpt():
sess = tf.Session()
# 加载模型结构
saver = tf.train.import_meta_graph('./save_path/file_name.meta')
# 只需要指定目录就可以恢复所有变量信息
saver.restore(sess, tf.train.latest_checkpoint('./save_path'))
# 直接获取保存的变量
print(sess.run('w:0'))
input_x = sess.graph.get_tensor_by_name('input_x:0')
# # 获取需要进行计算的operator
op = sess.graph.get_tensor_by_name('predict:0')
print(sess.run(op, feed_dict={input_x: np.expand_dims(mnist.test.images[15], axis=0)}))
sess.close()
restore_model_pb()
# 打印所有变量的值
# print_tensors_in_checkpoint_file("save_path/file_name", None, True)
运行结果:
[[ 2.59213091e-04 1.70691292e-05 1.03438069e-04 1.55748194e-02
2.95701193e-05 9.70679998e-01 7.14014686e-06 7.19119780e-05
1.32500082e-02 6.82865766e-06]]
java中调用以上模型文件请参考:java调用tensorflow模型文件。
【转载】:https://blog.csdn.net/ling913/article/details/80185535
https://blog.csdn.net/marsjhao/article/details/72829635
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/session_bundle/README.md
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md
https://blog.csdn.net/rabbit_judy/article/details/80054085
https://github.com/ZhuanZhiCode/TensorFlow-Java-Examples