TensorFlow 模型的保存与恢复

TensorFlow目前保存的模型文件主要有两种,ckpt与pb,二者之间的异同请见

https://zhuanlan.zhihu.com/p/32887066

下面,我以mnist手写数据集用softmax回归为例,说明如何对训练好的模型进行保存与恢复。

1. 训练模型并保存为模型文件

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
sess = tf.InteractiveSession()

x = tf.placeholder("float", shape=[None, 784], name='input_x')  # 输入图像占位符
y_ = tf.placeholder("float", shape=[None, 10])  # 标签类别占位符

# 模型参数一般用Variable来表示
W = tf.Variable(tf.zeros([784, 10]), name='w')  # 权重W是一个784x10的矩阵(因为我们有784个特征和10个输出值)
b = tf.Variable(tf.zeros([10]), name='b')  # 偏置b是一个10维的向量(因为我们有10个分类)

sess.run(tf.initialize_all_variables())  # 变量需要通过seesion初始化后,才能在session中使用
# 使用Tensorflow提供的回归模型softmax,y代表输出,把向量化后的图片x和权重矩阵W相乘,加上偏置b,然后计算每个分类的softmax概率值
y = tf.nn.softmax(tf.matmul(x, W) + b, name='predict')

cross_entropy = - tf.reduce_sum(y_ * tf.log(y))  # 计算交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  # 梯度下降算法以0.01的学习速率最小化交叉熵

# tf.argmax返回某个tensor对象在某一维上的其数据最大值所在的索引值
# 下面这行返回一组布尔值如[True, False, True, True]
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

# 把布尔值转换成浮点数,然后取平均值,[True, False, True, True] 会变成 [1,0,1,1] ,取平均值后得到 0.75
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

for i in range(1000):
    batch = mnist.train.next_batch(50)  # 每一步迭代加载50个训练样本,然后执行一次train_step
    sess.run(train_step, feed_dict={x: batch[0], y_: batch[1]})
    if i % 100 == 0:
        print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))  # 模型在测试数据集上面的正确率

# 随便从测试集中取一个例子做测试
print(sess.run(y, feed_dict={x: np.expand_dims(mnist.test.images[15], axis=0)}))
print(sess.run(tf.argmax(sess.run(y, feed_dict={x: np.expand_dims(mnist.test.images[15], axis=0)}), axis=1)))   # 预测结果
print(mnist.test.labels[15])    # 标签值

对上述代码不熟悉的请参考:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_pros.html

a.保存为ckpt格式的模型文件

saver = tf.train.Saver()
saver.save(sess, "save_path/file_name")

        生成的模型文件如下:

        TensorFlow 模型的保存与恢复_第1张图片

b.保存为pb格式的模型文件

builder = tf.saved_model.builder.SavedModelBuilder('./model2')
builder.add_meta_graph_and_variables(sess, ["mytag"])
builder.save()

         生成的模型文件如下:

        TensorFlow 模型的保存与恢复_第2张图片

运行结果:(第6个元素最大,表示数字5,说明预测正确)

0.2847
0.8778
0.8945
0.8972
0.9031
0.9015
0.9109
0.9007
0.8901
0.9061
[[  2.59213091e-04   1.70691292e-05   1.03438069e-04   1.55748194e-02
    2.95701193e-05   9.70679998e-01   7.14014686e-06   7.19119780e-05
    1.32500082e-02   6.82865766e-06]]
[5]
[ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.]


2. 模型文件的恢复与使用

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
import os
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


# pb模型的恢复
def restore_model_pb():
    sess = tf.Session()
    tf.saved_model.loader.load(sess, ['mytag'], os.getcwd() + '\model2')
    input_x = sess.graph.get_tensor_by_name('input_x:0')
    op = sess.graph.get_tensor_by_name('predict:0')
    print(sess.run(op, feed_dict={input_x: np.expand_dims(mnist.test.images[15], axis=0)}))
    sess.close()


# ckpt模型的恢复
def restore_model_ckpt():
    sess = tf.Session()
    # 加载模型结构
    saver = tf.train.import_meta_graph('./save_path/file_name.meta')
    # 只需要指定目录就可以恢复所有变量信息
    saver.restore(sess, tf.train.latest_checkpoint('./save_path'))
    # 直接获取保存的变量
    print(sess.run('w:0'))

    input_x = sess.graph.get_tensor_by_name('input_x:0')
    # # 获取需要进行计算的operator
    op = sess.graph.get_tensor_by_name('predict:0')
    print(sess.run(op, feed_dict={input_x: np.expand_dims(mnist.test.images[15], axis=0)}))
    sess.close()

restore_model_pb()
# 打印所有变量的值
# print_tensors_in_checkpoint_file("save_path/file_name", None, True)

运行结果:

[[  2.59213091e-04   1.70691292e-05   1.03438069e-04   1.55748194e-02
    2.95701193e-05   9.70679998e-01   7.14014686e-06   7.19119780e-05
    1.32500082e-02   6.82865766e-06]]


java中调用以上模型文件请参考:java调用tensorflow模型文件






你可能感兴趣的:(Python,tensorflow)