我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow提供了两种保存模型的方式,一种是使用tf.train.Saver函数来保存TensorFlow程序的参数和完整的模型结构,保存的文件后缀为 “.ckpt”;另一种方式是将计算图保存在一个 “.pb” 文件中。
在创建Saver对象时,有一个参数我们经常会用到,就是max_to_keep参数,这个是用来设置保存模型的个数,默认为5,即max_to_keep=5,保存最近的5个模型。如果想每训练一个epoch就保存一次模型,则可以将max_to_keep设置为None或0,如:
saver = tf.train.Saver(max_to_keep=None)
在创建完saver对象后,就可以保存训练好的模型了,如:
saver.save(sess, 'ckpt/mnist.ckpt', global_step=global_step)
第三个参数将训练的次数作为后缀加入到模型中:
saver.save(sess, 'my_model', global_step=0) # filename: my-model-0
saver.save(sess, 'my_model', global_step=1000) # filename: my-model-1000
其中需要注意的是:
tf.train.import_meta_graph() 方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。
有时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。
我们知道,graph_def文件中没有包含网络中的Variable值(通常情况权重定义为Variable),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。
在保存模型之前,需要先将网络的权重进行冻结,使用:
tensorflow.python.framework.graph_util.convert_variables_to_constant()
import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')
# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 这里需要填入输出tensor的名字
constant_graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
with tf.gfile.FastGFile(pd_file_path, mode = 'wb') as f:
f.write(constant_graph.SerializeToString())
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
pb_file_path="catdog.pb"
with open(pb_file_path, "rb") as f:
output_graph_def.ParseFromString(f.read()) #rb
_ = tf.import_graph_def(output_graph_def, name="")
with tf.Session() as sess:
tf.global_variables_initializer().run()
input_x = sess.graph.get_tensor_by_name("input:0") ####这就是刚才取名的原因
print (input_x)
out_softmax = sess.graph.get_tensor_by_name("softmax:0")
print (out_softmax)
out_label = sess.graph.get_tensor_by_name("output:0")
print (out_label)
img = np.array(Image.open(jpg_path).convert('L'))
img = transform.resize(img, (H, W, 3))
plt.figure("fig1")
plt.imshow(img)
img = img * (1.0 /255)
img_out_softmax = sess.run(out_softmax, feed_dict={input_x:np.reshape(img, [-1, H, W, 3])})