Tensorflow 模型参数保存加载

通过tf.train.Saver类实现保存模型

import tensorflow as tf
a = tf.Variable(tf.constant(1.0, shape = [1]), name = 'a')
b = tf.Variable(tf.constant(1.0, shape = [1]), name = 'b')
c = tf.multiply(a, b)

saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, "Model/model.ckpt")

通过tf.train.Saver类实现恢复模型

import tensorflow as tf
a = tf.Variable(tf.constant(1.0, shape = [1]), name = 'a')
b = tf.Variable(tf.constant(2.0, shape = [1]), name = 'b')
c = tf.multiply(a, b)

saver = tf.train.Saver()
sess = tf.Session()
saver.restore(sess, './Model/model.ckpt')
print(sess.run(c)) #[2.]

通过tf.train.Saver类实现恢复模型,支持在加载时给变量重命名

import tensorflow
d = tf.Variable(tf.constant(0.0, shape = [1]), name = 'dd')
e = tf.Variable(tf.constant(0.0, shape = [1]), name = 'ee')
c  = tf.multiply(d, e)

saver = tf.train.Saver({'a':d, 'b':e})
sess = tf.Session()
saver.restore(sess, './Model/model.ckpt')
print(sess.run(c)) #[2.]

通过tf.train.Saver类实现恢复部分模型,比如不要Resnet50最后一层fc

variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude = ['resnet50/fc'])
saver = tf.train.Saver(variables_to_restore)
sess = tf.Session()
saver.restore(sess, './Model/Resnet.ckpt')

通过tf.train.import_meta_graph()恢复模型,无需重复定义计算图。

import tensorflow as tf
saver = tf.train.import_meta_graph("Model/model.meta")
sess = tf.Session()
saver.restore(sess, "./Model/model.ckpt")
graph = tf.get_default_graph()
c = graph.get_tensor_by_name("c:0")
print(sess.run(c)) #[2.]

还能在之前模型的基础上增加自己的计算图层。们用meta图导入了一个预训练的TextCNN网络,然后将最后一层的输出个数改成2用于微调新的数据

import tensorflow as tf
saver = tf.train.import_meta_graph("Model/TextCNN.meta")
sess = tf.Session()
saver.restore(sess, "./Model/model.ckpt")
graph = tf.get_default_graph()
tf_x = graph.get_tensor_by_name('input_x:0')
tf_y = graph.get_tensor_by_name('input_y:0')
dense = graph.get_tensor_by_name('dense:0')
dense = tf.stop_gradient(dense) #因为只想训练最后一层,所以在这里要停止梯度后向传播
logist = tf.layers.dense(dense, 2)
pred = tf.argmax(tf.nn.softmax(logist),1)

通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于pb文件中

import tensorflow as tf
from tensorflow.python.framework import graph_util
a = tf.Variable(tf.constant(1.0, shape = [1]), name = 'a')
b = tf.Variable(tf.constant(1.0, shape = [1]), name = 'b')
c = tf.multiply(a, b)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['multiply'])
with tf.gfile.GFile('Model/model.pb', 'wb') as f:
    f.write(output_grapg_def.SerializeToString())

从pb文件中恢复模型,并实现预测

import tensorflow as tf
from tensorflow.python.platform import gfile
tf_x = tf.placeholder(tf.float32, shape = [None, None], name = 'x')
sess = tf.Session()
with gfile.FastGFile('Model/model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    output = tf.import_graph_def(graph_def, input_map = {'x:0': tf_x},  return_elements=['pred:0'])
pred = sess.run(output, feed_dict = {tf_x: x})

你可能感兴趣的:(Tensorflow 模型参数保存加载)