— 对于模型的保存和恢复,前文已经做了介绍,然而读者可能已经注意到,在设定的保存文件夹中有着4个不同的文件类型:
可以得知,根据需要每个文件类型都有其不同的用处,但是仅仅知道这些还不够,对于Tensorflow工作人员来说,需要更进一步了解不同文件所处的作用。
— 在介绍存储文件之前,先对Saver类进行一下解释。在不同的会话中,当需要将数据在硬盘上进行保存时,就可以使用Saver类。这个Saver构造类允许你去控制3个元素:
—Saver类可以处理图中元数据和变量数据的保存和恢复,而我们唯一需要做的是,告诉Saver类需要保存哪个图和哪些变量。在默认的情况下,Saver类能处理默认图中包含的所有变量。但是,我们也可以创建出很多的Saver类,去保存想要的任何子图。
—介绍完Saver类,对于模型存储来说,这里有4个文件类型,依次如下:
在对模型进行保存和恢复时,Saver类将保存于图像关联的任何元数据,这意味着加载元检查点还将恢复与图相关联的所有空变量、操作和集合。
现在抛开理论介绍而对模型进行恢复与处理。,由于Tensorflow将整体的“图”文件存储在meta后缀的文件中,而将权重存储在ckpt后缀的文件中,在其具体使用时,对于模型权重的注入则是根据相应的名称来进行,因此,如果需要对模型中不同的权重进行重新注入的话,那么第一步就是需要赋予不同的权重以名称。
with tf.variable_scope("var"):
self.a_val = tf.Variable(tf.random_normal([1]),name="a_val")
self.b_val = tf.Variable(tf.random_normal([1]),name="b_val")
这里首先使用了tf.variable_scope对域进行了定义,之后在定义域内对输入变量进行赋值。最终形成的名称为:
var/a_val
首先是对于线性回归类的定义,在前面已经说了,需要对不同的变量或者占位符以及不同的函数定义其在图中的名称,这里为了简便,只定义了变量和占位符的名称:
import tensorflow as tf
class LineRegModel:
def __init__(self):
with tf.variable_scope("var"):
self.a_val = tf.Variable(tf.random_normal([1]),name="a_val")
self.b_val = tf.Variable(tf.random_normal([1]),name="b_val")
self.x_input = tf.placeholder(tf.float32,name="input_placeholder")
self.y_label = tf.placeholder(tf.float32,name="result_placeholder")
self.y_output = tf.add(tf.multiply(self.x_input, self.a_val), self.b_val,name="output")
self.loss = tf.reduce_mean(tf.pow(self.y_output - self.y_label, 2))
def get_saver(self):
return tf.train.Saver()
def get_op(self):
return tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)
在程序中可以看到,这里对每个变量或占位符都设置了相应的名称,而对变量域又设置了对应的域名。
import tensorflow as tf
import numpy as np
import global_variable
import lineRegulation_model as model
train_x = np.random.rand(5)
train_y = 5 * train_x + 3.2 # y = 5 * x + 3
model = model.LineRegModel()
a_val = model.a_val
b_val = model.b_val
x_input = model.x_input
y_label = model.y_label
y_output = model.y_output
loss = model.loss
optimize = model.get_op()
saver = model.get_saver()
if __name__ == "__main__":
sess = tf.Session()
sess.run(tf.global_variables_initializer())
flag = True
epoch = 0
while flag:
epoch += 1
_ , loss_val = sess.run([optimize,loss],feed_dict={x_input:train_x,y_label:train_y})
if loss_val < 1e-6:
flag = False
print(a_val.eval(sess) , " ", b_val.eval(sess))
print("-----------%d-----------"%epoch)
print(a_val.op)
saver.save(sess,global_variable.save_path)
print("model save finished")
sess.close()
可以看到,其中的节点名被定义为“var/a_val”,这是类中被定义是赋予的变量名称。
对于模型的恢复来说,需要首先恢复模型的整个图文件,之后从图文件中读取相应的节点信息。
saver = tf.train.import_meta_graph('./model/save_model.ckpt.meta')
Saver方法先从图中获取了整个图的信息,之后根据节点名称将不同的变量或者占位符重新按名称赋值。
#读取placeholder和最终的输出结果
graph = tf.get_default_graph()
a_val = graph.get_tensor_by_name('var/a_val:0')
input_placeholder=graph.get_tensor_by_name('input_placeholder:0')
labels_placeholder=graph.get_tensor_by_name('result_placeholder:0')
y_output=graph.get_tensor_by_name('output:0')#最终输出结果的tensor
而具体的权重恢复则需要在对话中完成。
with tf.Session() as sess:
saver.restore(sess, './model/save_model.ckpt')
完整代码
import tensorflow as tf
saver = tf.train.import_meta_graph('./model/save_model.ckpt.meta')
#读取placeholder和最终的输出结果
graph = tf.get_default_graph()
a_val = graph.get_tensor_by_name('var/a_val:0')
input_placeholder=graph.get_tensor_by_name('input_placeholder:0')
labels_placeholder=graph.get_tensor_by_name('result_placeholder:0')
y_output=graph.get_tensor_by_name('output:0')
with tf.Session() as sess:
saver.restore(sess, './model/save_model.ckpt')
result = sess.run(y_output, feed_dict={input_placeholder: [1]})
print(result)
print(sess.run(a_val))
读者可能注意到,在程序中采用通过名称获取对应的变量值的时候,冒号的右边有一个0符号,这是在Tensorflow的图运行中为了进行参数的复用而使用的标记类型,这里读者可以对其忽略而直接使用,程序运行的结果如下:
如果要对模型的特定值进行恢复,同样可以使用这个首先载入图文件之后使用权重对其赋值的办法。
import tensorflow as tf
saver = tf.train.import_meta_graph('./model/save_model.ckpt.meta')
graph = tf.get_default_graph()
a_val = graph.get_tensor_by_name('var/a_val:0')
y_output=graph.get_tensor_by_name('output:0')
with tf.Session() as sess:
saver.restore(sess, './model/save_model.ckpt')
print(sess.run(a_val))
可以看到这里只定义了变量a_val,并通过相应的名称将其重新获取。这种方法可以获取到模型中特定的变量或者节点的值,其最终结果如下: