TensorFlow提供了一个非常简单的API来保存和还原一个神经网络模型。这个API就是tf.train.Saver类。以下代码给出了保存TesnsorFlow计算图的方法。
import tensorflow as tf #声明两个变量并计算他们的和 v1 = tf.Variable(tf.constant(1.0, shape = [1]), name = "v1") v2 = tf.Variable(tf.constant(2.0, shape = [1]), name = "v2") result = v1 + v2 init_op = tf.initialize_all_variables() #声明tf.train.Saver类用于保存模型 saver = tf.train.Saver() with tf.Session() as sess: sess.run(init_op) #将模型保存到/path/to/model/model.ckpt文件 saver.save(sess, "/path/to/model/model.ckpt")
上述代码实现了持久化一个简单的TensorFlow模型的功能。通过saver.save函数将TensorFlow模型保存到了指定路径。虽然该程序仅指定了一个文件路径,但是在这个文件目录下会出现三个文件。因为TensorFlow会将计算图的结构和图上的参数取值分开进行保存。
第一个文件为model.ckpt.meta,它保存了TensorFlow计算图的结构,可以简单理解为神经网络的网络结构。第二个文件为model.ckpt,这个文件中保存了TensorFlow程序中每一个变量的取值。最后一个文件为checkpoint文件,这个文件保存了一个目录下所有的模型文件列表。
下面是加载这个已经保存的TensorFlow模型的方法。
import tensorflow as tf #使用和保存模型代码中一样的方式来声明变量 v1 = tf.Variable(tf.constant(1.0, shape = [1]), name = "v1") v2 = tf.Variable(tf.constant(2.0, shape = [1]), name = "v2") result = v1 + v2 saver = tf.train.Saver() with tf.Session() as sess: #加载已经保存的模型,并通过已经保存的模型中的变量的值计算加法 saver.restore(sess, "path/to/model/model.ckpt") print(sess.run(result))
两段代码唯一不同的是,在加载模型的代码没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。
可以直接加载已经持久化的图。
import tensorflow as tf #直接加载持久化的图 saver = tf.train.import_meta_graph("/path/to/model.ckpt/model.ckpt.meta") with tf.Session() as sess: saver.restore(sess, "/path/to/medel/model.ckpt") #通过张量的名称获取张量 print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))
上面的程序中,默认保存和加载了TensorFlow计算图上定义的全部变量。但有时只需要保存或者加载部分变量,这时需要在声明tf.train.Saver类时提供一个列表指定需要保存或者加载的变量。
tf.train.Saver类同样支持在保存或者加载时给变量重命名。
#这里声明的变量名称和已经保存的模型中的变量的名称不同 v1 = tf.Variable(tf.constant(1.0, shape = [1]), name = "other-v1") v2 = tf.Variable(tf.constant(2.0, shape = [1]), name = "other-v2") #如果直接使用tf.train.Saver()来加载模型会报变量找不到的错误。 #使用一个字典来重命名变量就可以加载原来的模型了。这个字典指定了原来名称为v1的变量心在加载 #到变量v1中(名称为other-v1),名称为v2的变量加载到变量v2中(名称为other-v2) saver = tf.train.Saver({"v1":v1, "v2": v2})
这样做的主要目的之一是方便使用变量的滑动平均值。在TensorFlow中,每一个变量的滑动平均值是通过影子变量维护的,所以要获取变量的滑动平均值实际上就是获取这个影子变量的取值。如果在加载模型是直接将影子变量映射到变量自身,那么在使用训练好的模型时就不需要在调用函数来获取变量的滑动平均值了。
以下为一个保存滑动平均值模型的样例。
import tensorflow as tf v= tf.Varibale(0, dtype = tf.float32, name = "v") #在没有申请华东平均模型时只有一个变量v,所以下面的语句只会输出“v:0”。 for variables in tf.all_variables(): print(variables.name) ema = tf.train.ExponentialMovingAverage(0.99) maintain_average_op = ema.apply(tf.all_variables()) #在申请滑动平均值模型之后,TensorFlow会自动生成一个影子变量 #v/ExponentialMoving Average。于是下面的语句会输出 #“v:0”和“v/ExpinentialMovingAverage:0”。 for variables in tf.all_variables(): print(variables.name) saver = tf.train.Saver() with tf.Session() as sess: init_op = tf.initialize_all_variables() sess.run(init_op) sess.run(tf.assign(v, 10)) sess.run(maintain_average_op) #保存时,TensorFlow会将v:0和v/ExponentialMovingAverage:0两个变量都存下来 saver.save(sess, "path/to/model.ckpt") print(sess.run([v, ema.average(v)])) #输出[10.0, 0.99999905]
以下代码给出了如何通过变量的重命名直接读取变量的滑动平均值。从下面程序的输出可以看到,读取的变量v的值实际上是上面代码中变量v的滑动平均值。通过这个方法,可以使用完全一样的代码来计算滑动平均模型前向传播的结果。
v = tf.Variable(0, dtype = tf.float32, name = "v") #通过变量重命名将原来变量v的滑动平均值赋值给v。 saver = tf.train.Saver({"v/ExponentialMovingAverage": v}) with tf.Session() as sess: saver.restore(sess, "/path/to/model/model.ckpt") print(sess.run(v))#输出0.099999905,这个值就是原来模型中变量v的滑动平均值
使用tf.train.Saver会保存运行TensorFlow程序所需要的全部信息,然而有时并不需要某些信息。于是TensorFlow提供了convert_variables_to_constnats函数,通过这个函数可以将计算图中的变量以及取值通过常量的形式保存,这样整个TensorFlow计算图可以统一存放在一个文件中。
import tensorflow as tf from tensorflow.python.framework import graph_util v1 = tf.Variable(tf.constant(1.0, shape = [1]), name = "v1") v2 = tf.Variable(tf.constant(2.0, shape = [1]), name = "v2") result = v1 + v2 init_op = tf.initialize_all_variables() with tf.Session() as sess: sess.run(init_op) #导出当前计算图的GraphDef部分,只需要这一部分就可以完成从输入层到输出层的计算过程 graph_def = tf.get_default_graph().as_graph_def() #将图中的变量及其取值转化为常量,同时将图中不必要的节点去掉在下面一行代码中 #最后一个参数['add']给出了需要保存的节点名称,add节点是上面两个变量相加的 #操作 output_grpah_def = graph_util_convert_variables_to_constants(sess, graph_def, ['add']) #将导出的模型存入文件 with tf.gfile.GFile("/path/to/model/combined_model.pb", "wb") as f: f.write(output_grpah_def.SerializerToString())
通过下面的程序可以直接计算定义的加法运算的结果,当只需要得到计算图中某个节点的取值时 ,此处给了一个更加简便的方法。
import tensorflow as tf from tensorflow.python.platform import gfile with tf.Session() as sess: model_filename = "/path/to/model/combined_model.pb" #读取保存的模型文件,并将文件解析成对应的GraphDef Protocol Buffer with gfile.FastGFile(model_filename, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) #将graph_def中保存的图加载到当前的图中。return_elements = ["add:0"]给出了返回 #的张量的名称。在保存的时候给出的是计算节点的名称,所以为“add”。在加载的时候给出 #的是张量的名称,所以是add:0 result = tf.import_graph_def(graph_def, return_elements = ["add:0"]) print(sess.run(result))