tensorflow模型保存后继续训练_TensorFlow 训练模型的保存&加载

什么是Tensorflow的模型

Tensorflow的模型主要包括神经网络的架构设计(或者称为计算图的设计)和已经训练好的网络参数。因此,Tensorflow模型包括的主要文件:

“.meta”:包含了计算图的结构

“.data”:包含了变量的值

“.index”:确认checkpoint

“checkpiont”:一个protocol buffer,包含了最近的一些checkpoints

存储一个Tensorflow的模型

当我们训练的神经网络模型的损失函数或者精度收敛时,我们需要把参数或者网络结构存储起来。如果我们想要存储整个网络结构和该网络的所有参数,我们需要创建一个tf.train.Saver()的实例。Tensorflow变量的作用域仅在Session内部。因此,我们必须在一个Session的内部存储有关的数据。

saver.save(sess,'my_test_model')

sess是我们创建的一个Session实例,my_test_model是我们给模型的命名。

具体的实例:

import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')

w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')

saver = tf.train.Saver()

sess = tf.Session()

sess.run(tf.global_variables_initializer())

saver.save(sess, './my_test_model')

sess.close()

执行上述语句,我们会同级目录下看到新增的文件:

my_test_model.data-00000-of-00001

my_test_model.index

my_test_model.meta

如果网络架构更改了,Tensorflow会重写上述的文件。

如果我们想要每1000步保存一次,那么需要更改语句:

saver.save(sess, 'my_test_model', global_step=1000)

那么当训练时,我们会每1000次迭代存储一次模型。.meta会在第一次到达1000次迭代时创建,之后的每千步,就不需要在重新创建.meta文件了。只要图的架构 不更改,就不需要重新创建.meta文件。 如果不写步数,默认每次迭代保存一次。

如果我们要仅仅保留最近4次创建的模型,并且每两个小时存储一次模型,可以进行下面的操作:

# saves a model every 2 hours and maximum 4 latest models are saved.

saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

如果我们在tf.train.Saver()中不指定任何参数,那么Tensorflow会默认保存所有的变量。假设我们只想保留部分变量或者collection,那么需要显式地表明需要保留的对象。当创建tf.train.Saver()对象时,使用一个包含有关变量的list或者字典声明。比如:

import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')

w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')

saver = tf.train.Saver([w1, w2])

sess = tf.Session()

sess.run(tf.global_variables_initializer())

saver.save(sess, './my_test_model')

sess.close()

导入一个训练好的模型

如果我们要导入一个训练好的模型,需要做以下两步:

创建一个网络

使用函数:

saver = tf.train.import_meta_graph('my_test_model-1000.meta')

把存储在my_test_model-1000.meta加载到saver当中。这个操作知识会把在.meta文件中定义的网络追加到当前网络的后面,我们仍然需要加载原来网络的参数数值。

加载参数

操作如下:

with tf.Session() as sess:

new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')

new_saver.restore(sess, tf.train.lasters_checkpoint('./'))

在这之后,w1和w2的数据就会被重新加载进来。

对导入的模型进行的操作

现在,学着加载模型,把模型用于预测、训练甚至更改模型的架构。现在构造一个简单的网络模型,保存并重新导入。注意一点:tf.placeholder的数据不会被保存 !!!!

先定义训练文件:

import tensorflow as tf

# 定义用于恢复变量的例子

w1 = tf.placeholder(dtype=tf.float32, name="w1")

w2 = tf.placeholder(dtype=tf.float32, name="w2")

b1 = tf.Variable(2.0, name="bias")

feed_dict = {w1: 4, w2: 8}

# 定义用于恢复操作的例子 w4=w3*b1,w3=(w1+w2)*b1

w3 = tf.add(w1, w2, name="part_op")

w4 = tf.multiply(w3, b1, name="op_to_restore")

sess = tf.Session()

sess.run(tf.global_variables_initializer()) # 时刻记着,要初始化

saver = tf.train.Saver()

print(sess.run(w4, feed_dict)) # 24.0

saver.save(sess, './my_test_model', global_step=1000)

sess.close()

定义加载文件:

import tensorflow as tf

sess = tf.Session()

saver = tf.train.import_meta_graph('my_test_model-1000.meta')

saver.restore(sess, tf.train.latest_checkpoint('./'))

graph = tf.get_default_graph()

w1 = graph.get_tensor_by_name("w1:0")

w2 = graph.get_tensor_by_name("w2:0")

feed_dict = {w1: 13.0, w2: 17.0}

# w4=w3*b1,w3=(w1+w2)*b1

op_to_restore = graph.get_tensor_by_name("op_to_restore:0") # 60.0

print(sess.run(op_to_restore, feed_dict))

sess.close()

当导入模型的时候,不但需要恢复计算图和相关的参数,而且需要重新对tf.placeholder喂数据。通过graph.get_tensor_by_name获取保存的操作和占位符。如果我们想要使用网络计算,仅需要给不同的占位符添加不同的数据即可。

如果我们想要对原来的网络添加更多的层数并接着训练它,可以按照下面的步骤处理:

import tensorflow as tf

sess = tf.Session()

# 恢复计算图

saver = tf.train.import_meta_graph('my_test_model-1000.meta')

saver.restore(sess, tf.train.latest_checkpoint('./'))

# 获取占位符

graph = tf.get_default_graph()

w1 = graph.get_tensor_by_name("w1:0")

w2 = graph.get_tensor_by_name("w2:0")

feed_dict = {w1: 13.0, w2: 17.0}

# 恢复操作

op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

# 增加新的操作

add_on_op = tf.multiply(op_to_restore, 2.0)

# 别忘了喂数据

print(sess.run(add_on_op, feed_dict))

sess.close()

由此可以看出,只需要把原来的操作加载完毕后,当成一个输出数据接入新的网络即可。

也可以把原来网络的一部分加载 到新的网络中,比如下面的操作:

先更改之前的一行代码

w3 = tf.add(w1, w2, name="part_op")

加载操作:

import tensorflow as tf

sess = tf.Session()

saver = tf.train.import_meta_graph("my_test_model-1000.meta")

saver.restore(sess, tf.train.latest_checkpoint('./'))

graph = tf.get_default_graph()

w1 = graph.get_tensor_by_name("w1:0")

w2 = graph.get_tensor_by_name("w2:0")

feed_dict = {w1: 13.0, w2: 14.0}

w3 = graph.get_tensor_by_name("part_op:0")

op = tf.multiply(w3, 4)

print(sess.run(op, feed_dict)) # 108.0

sess.close()

使用SavedModel的格式

SavedMode类把Saver类进行了一个更高层的封装,开发效率可能会更高,但是暂时没有前一种方法常用。Saver类更看重对变量的封装, 而SavedModel更看重压缩封装保存所有有用的信息。

保存操作:

import tensorflow as tf

tf.reset_default_graph()

w1 = tf.Variable(1.0, name="w1")

w2 = tf.Variable(2.0, name="w2")

w3 = tf.multiply(w1, w2, name="w3")

builder = tf.saved_model.builder.SavedModelBuilder('./SavedModel/')

with tf.Session() as sess:

sess.run(tf.global_variables_initializer())

sess.run(w3)

builder.add_meta_graph_and_variables(sess,

[tf.saved_model.tag_constants.TRAINING],

signature_def_map=None,

assets_collection=None)

builder.save()

读取操作:

import tensorflow as tf

with tf.Session() as sess:

tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING],

'./SavedModel')

w1 = sess.run('w1:0')

w2 = sess.run('w2:0')

w3 = sess.run('w3:0')

print(w1, w2, w3)

Enjoy your coding!

你可能感兴趣的:(tensorflow模型保存后继续训练_TensorFlow 训练模型的保存&加载)