什么是Tensorflow的模型
Tensorflow的模型主要包括神经网络的架构设计(或者称为计算图的设计)和已经训练好的网络参数。因此,Tensorflow模型包括的主要文件:
“.meta”:包含了计算图的结构
“.data”:包含了变量的值
“.index”:确认checkpoint
“checkpiont”:一个protocol buffer,包含了最近的一些checkpoints
存储一个Tensorflow的模型
当我们训练的神经网络模型的损失函数或者精度收敛时,我们需要把参数或者网络结构存储起来。如果我们想要存储整个网络结构和该网络的所有参数,我们需要创建一个tf.train.Saver()的实例。Tensorflow变量的作用域仅在Session内部。因此,我们必须在一个Session的内部存储有关的数据。
saver.save(sess,'my_test_model')
sess是我们创建的一个Session实例,my_test_model是我们给模型的命名。
具体的实例:
import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './my_test_model')
sess.close()
执行上述语句,我们会同级目录下看到新增的文件:
my_test_model.data-00000-of-00001
my_test_model.index
my_test_model.meta
如果网络架构更改了,Tensorflow会重写上述的文件。
如果我们想要每1000步保存一次,那么需要更改语句:
saver.save(sess, 'my_test_model', global_step=1000)
那么当训练时,我们会每1000次迭代存储一次模型。.meta会在第一次到达1000次迭代时创建,之后的每千步,就不需要在重新创建.meta文件了。只要图的架构 不更改,就不需要重新创建.meta文件。 如果不写步数,默认每次迭代保存一次。
如果我们要仅仅保留最近4次创建的模型,并且每两个小时存储一次模型,可以进行下面的操作:
# saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)
如果我们在tf.train.Saver()中不指定任何参数,那么Tensorflow会默认保存所有的变量。假设我们只想保留部分变量或者collection,那么需要显式地表明需要保留的对象。当创建tf.train.Saver()对象时,使用一个包含有关变量的list或者字典声明。比如:
import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1, w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './my_test_model')
sess.close()
导入一个训练好的模型
如果我们要导入一个训练好的模型,需要做以下两步:
创建一个网络
使用函数:
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
把存储在my_test_model-1000.meta加载到saver当中。这个操作知识会把在.meta文件中定义的网络追加到当前网络的后面,我们仍然需要加载原来网络的参数数值。
加载参数
操作如下:
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
new_saver.restore(sess, tf.train.lasters_checkpoint('./'))
在这之后,w1和w2的数据就会被重新加载进来。
对导入的模型进行的操作
现在,学着加载模型,把模型用于预测、训练甚至更改模型的架构。现在构造一个简单的网络模型,保存并重新导入。注意一点:tf.placeholder的数据不会被保存 !!!!
先定义训练文件:
import tensorflow as tf
# 定义用于恢复变量的例子
w1 = tf.placeholder(dtype=tf.float32, name="w1")
w2 = tf.placeholder(dtype=tf.float32, name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}
# 定义用于恢复操作的例子 w4=w3*b1,w3=(w1+w2)*b1
w3 = tf.add(w1, w2, name="part_op")
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer()) # 时刻记着,要初始化
saver = tf.train.Saver()
print(sess.run(w4, feed_dict)) # 24.0
saver.save(sess, './my_test_model', global_step=1000)
sess.close()
定义加载文件:
import tensorflow as tf
sess = tf.Session()
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}
# w4=w3*b1,w3=(w1+w2)*b1
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") # 60.0
print(sess.run(op_to_restore, feed_dict))
sess.close()
当导入模型的时候,不但需要恢复计算图和相关的参数,而且需要重新对tf.placeholder喂数据。通过graph.get_tensor_by_name获取保存的操作和占位符。如果我们想要使用网络计算,仅需要给不同的占位符添加不同的数据即可。
如果我们想要对原来的网络添加更多的层数并接着训练它,可以按照下面的步骤处理:
import tensorflow as tf
sess = tf.Session()
# 恢复计算图
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
# 获取占位符
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}
# 恢复操作
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
# 增加新的操作
add_on_op = tf.multiply(op_to_restore, 2.0)
# 别忘了喂数据
print(sess.run(add_on_op, feed_dict))
sess.close()
由此可以看出,只需要把原来的操作加载完毕后,当成一个输出数据接入新的网络即可。
也可以把原来网络的一部分加载 到新的网络中,比如下面的操作:
先更改之前的一行代码
w3 = tf.add(w1, w2, name="part_op")
加载操作:
import tensorflow as tf
sess = tf.Session()
saver = tf.train.import_meta_graph("my_test_model-1000.meta")
saver.restore(sess, tf.train.latest_checkpoint('./'))
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 14.0}
w3 = graph.get_tensor_by_name("part_op:0")
op = tf.multiply(w3, 4)
print(sess.run(op, feed_dict)) # 108.0
sess.close()
使用SavedModel的格式
SavedMode类把Saver类进行了一个更高层的封装,开发效率可能会更高,但是暂时没有前一种方法常用。Saver类更看重对变量的封装, 而SavedModel更看重压缩封装保存所有有用的信息。
保存操作:
import tensorflow as tf
tf.reset_default_graph()
w1 = tf.Variable(1.0, name="w1")
w2 = tf.Variable(2.0, name="w2")
w3 = tf.multiply(w1, w2, name="w3")
builder = tf.saved_model.builder.SavedModelBuilder('./SavedModel/')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(w3)
builder.add_meta_graph_and_variables(sess,
[tf.saved_model.tag_constants.TRAINING],
signature_def_map=None,
assets_collection=None)
builder.save()
读取操作:
import tensorflow as tf
with tf.Session() as sess:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING],
'./SavedModel')
w1 = sess.run('w1:0')
w2 = sess.run('w2:0')
w3 = sess.run('w3:0')
print(w1, w2, w3)
Enjoy your coding!