TensorFlow 模型的保存与恢复

                       TensorFlow 模型的保存与恢复

TensorFlow目前保存的模型文件主要有两种,ckpt与pb,二者之间的异同请见

https://zhuanlan.zhihu.com/p/32887066

CKPT,首先这种模型文件是依赖 TensorFlow 的,只能在其框架下使用;其次,在恢复模型之前还需要再定义一遍网络结构,然后才能把变量的值恢复到网络中;PB 文件,它具有语言独立性,可独立运行,封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。

下面,我以mnist手写数据集用softmax回归为例,说明如何对训练好的模型进行保存与恢复。

1. 训练模型并保存为模型文件

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
 
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
sess = tf.InteractiveSession()
 
x = tf.placeholder("float", shape=[None, 784], name='input_x')  # 输入图像占位符
y_ = tf.placeholder("float", shape=[None, 10])  # 标签类别占位符
 
# 模型参数一般用Variable来表示
W = tf.Variable(tf.zeros([784, 10]), name='w')  # 权重W是一个784x10的矩阵(因为我们有784个特征和10个输出值)
b = tf.Variable(tf.zeros([10]), name='b')  # 偏置b是一个10维的向量(因为我们有10个分类)
 
sess.run(tf.initialize_all_variables())  # 变量需要通过seesion初始化后,才能在session中使用
# 使用Tensorflow提供的回归模型softmax,y代表输出,把向量化后的图片x和权重矩阵W相乘,加上偏置b,然后计算每个分类的softmax概率值
y = tf.nn.softmax(tf.matmul(x, W) + b, name='predict')
 
cross_entropy = - tf.reduce_sum(y_ * tf.log(y))  # 计算交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  # 梯度下降算法以0.01的学习速率最小化交叉熵
 
# tf.argmax返回某个tensor对象在某一维上的其数据最大值所在的索引值
# 下面这行返回一组布尔值如[True, False, True, True]
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
 
# 把布尔值转换成浮点数,然后取平均值,[True, False, True, True] 会变成 [1,0,1,1] ,取平均值后得到 0.75
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
 
for i in range(1000):
    batch = mnist.train.next_batch(50)  # 每一步迭代加载50个训练样本,然后执行一次train_step
    sess.run(train_step, feed_dict={x: batch[0], y_: batch[1]})
    if i % 100 == 0:
        print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))  # 模型在测试数据集上面的正确率
 
# 随便从测试集中取一个例子做测试
print(sess.run(y, feed_dict={x: np.expand_dims(mnist.test.images[15], axis=0)}))
print(sess.run(tf.argmax(sess.run(y, feed_dict={x: np.expand_dims(mnist.test.images[15], axis=0)}), axis=1)))   # 预测结果
print(mnist.test.labels[15])    # 标签值

a.保存为ckpt格式的模型文件

saver = tf.train.Saver()
saver.save(sess, "save_path/file_name")

 生成的模型文件如下:

        

checkpoint文件保存了一个录下多有的模型文件列表,model.ckpt.meta保存了TensorFlow计算图的结构信息,model.ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同,但加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的。

b.保存为pb格式的模型文件

builder = tf.saved_model.builder.SavedModelBuilder('./model2')
builder.add_meta_graph_and_variables(sess, ["mytag"])
builder.save()

生成的模型文件如下:

        

运行结果:(第6个元素最大,表示数字5,说明预测正确)

0.2847
0.8778
0.8945
0.8972
0.9031
0.9015
0.9109
0.9007
0.8901
0.9061
[[  2.59213091e-04   1.70691292e-05   1.03438069e-04   1.55748194e-02
    2.95701193e-05   9.70679998e-01   7.14014686e-06   7.19119780e-05
    1.32500082e-02   6.82865766e-06]]
[5]
[ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.]

TensorFlow 模型的保存与恢复_第1张图片

2. 模型文件的恢复与使用

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
import os
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
 
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
 
 
# pb模型的恢复
def restore_model_pb():
    sess = tf.Session()
    tf.saved_model.loader.load(sess, ['mytag'], os.getcwd() + '\model2')
    input_x = sess.graph.get_tensor_by_name('input_x:0')
    op = sess.graph.get_tensor_by_name('predict:0')
    print(sess.run(op, feed_dict={input_x: np.expand_dims(mnist.test.images[15], axis=0)}))
    sess.close()
 
 
# ckpt模型的恢复
def restore_model_ckpt():
    sess = tf.Session()
    # 加载模型结构
    saver = tf.train.import_meta_graph('./save_path/file_name.meta')
    # 只需要指定目录就可以恢复所有变量信息
    saver.restore(sess, tf.train.latest_checkpoint('./save_path'))
    # 直接获取保存的变量
    print(sess.run('w:0'))
 
    input_x = sess.graph.get_tensor_by_name('input_x:0')
    # # 获取需要进行计算的operator
    op = sess.graph.get_tensor_by_name('predict:0')
    print(sess.run(op, feed_dict={input_x: np.expand_dims(mnist.test.images[15], axis=0)}))
    sess.close()
 
restore_model_pb()
# 打印所有变量的值
# print_tensors_in_checkpoint_file("save_path/file_name", None, True)

运行结果:

[[  2.59213091e-04   1.70691292e-05   1.03438069e-04   1.55748194e-02
    2.95701193e-05   9.70679998e-01   7.14014686e-06   7.19119780e-05
    1.32500082e-02   6.82865766e-06]]


java中调用以上模型文件请参考:java调用tensorflow模型文件。

【转载】:https://blog.csdn.net/ling913/article/details/80185535

                  https://blog.csdn.net/marsjhao/article/details/72829635

                  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/session_bundle/README.md

                  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md

                  https://blog.csdn.net/rabbit_judy/article/details/80054085

                  https://github.com/ZhuanZhiCode/TensorFlow-Java-Examples

 

你可能感兴趣的:(Python,Deep,Learning)