Tensorflow模型保存和恢复 meta,ckpt含义

转载:https://blog.csdn.net/sinat_34474705/article/details/78995196

什么是Tensorflow模型?

当你训练好一个神经网络后,你会想保存好你的模型便于以后使用并且用于生产。因此,什么是Tensorflow模型?Tensorflow模型主要包含网络设计(或者网络图)和训练好的网络参数的值。所以Tensorflow模型有两个主要的文件:

a) Meta图: 
Meta图是一个协议缓冲区(protocol buffer),它保存了完整的Tensorflow图;比如所有的变量、运算、集合等。这个文件的扩展名是.meta

b) Checkpoint 文件 
这是一个二进制文件,它保存了权重、偏置项、梯度以及其他所有的变量的取值,扩展名为.ckpt。但是, 从0.11版本开始,Tensorflow对改文件做了点修改,checkpoint文件不再是单个.ckpt文件,而是如下两个文件:

mymodel.data-00000-of-00001
mymodel.index
  • 1
  • 2

其中, .data文件包含了我们的训练变量。除此之外,还有一个叫checkpoint的文件,它保留了最新的checkpoint文件的记录。

总结一下,对于0.10之后的版本,tensorflow模型包含以下文件:

model files 
但对于0.11之前的版本,只包含三个文件:

inception_v1.meta
inception_v1.ckpt
checkpoin
  • 1
  • 2
  • 3

现在我们已经知道Tensorflow模型是什么样子的,让我们继续学习如何保存模型。

保存Tensorflow模型

假如你正在训练一个用于图像分类的卷积神经网络(training a convolutional neural network for image classification)。通常你会先观察损失和准确率,一旦发现网络收敛,就可以手动停止训练过程或者直接训练固定迭代次数。当训练完成后,我们想要保存所有的变量和网络图便于以后使用。因此在Tensorflow中, 为了保存网络图和所有参数的值,我们应该创建tf.train.Saver()这个类的一个对象。

saver = tf.train.Saver()
  • 1

记住Tensorflow变量只有在会话(session)中才能激活。因此,你需要在会话中调用你刚创建的对象的保存方法。

saver.save(sess, "my-test-model")
  • 1

这里,sess是一个session对象,“my-test-model”是你的模型名字。让我们看一个完整的例子:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')

# This will save following files in Tensorflow v >= 0.11
# my_test_model.data-00000-of-00001
# my_test_model.index
# my_test_model.meta
# checkpoint
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

如果我们要在1000次迭代后保存模型,我们应该在调用保存方法时传入步数计数:

saver.save(sess, "my_test_model", global_step=1000)
  • 1

这会在模型名称后加一个“-1000”并且会创建如下文件:

my_test_model-1000.index
my_test_model-1000.meta
my_test_model-1000.data-00000-of-00001
checkpoint
  • 1
  • 2
  • 3
  • 4

假设在训练过程中,我们要每1000次迭代保存我们的模型,因此.meta文件会在第一次(1000次迭代)时创建,我们并不需要之后每1000次迭代都保存一遍这个文件(我们在2000,3000…迭代时都不需要保存这个文件,因为这个文件始终不变)。我们只需要保存这个模型供以后使用,因为模型图不会变化。所以,当我们不想重写meta图的时候,我们这样写:

saver.save(sess, "my-model", global_step=step, write_meta_graph=False)
  • 1

如果你只想保留4个最新的模型并且在训练过程中每过2小时保存一次模型,你可以使用max_to_keep和keep_checkpoint_every_n_hours,就像这样:

#saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)
  • 1
  • 2

注意,如果我们在tf.train.Saver()中不指定任何东西,它将保存所有的变量。要是我们不想保存所有的变量而只是一部分变量。我们可以指定我们想要保存的变量/集合。当创建tf.train.Saver()对象的时候,我们给它传递一个我们想要保存的变量的字典列表。我们来看一个例子:

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1,w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model',global_step=1000)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

当需要的时候,这个代码可以用来保存Tensorflow图中的指定部分。

导入预训练模型

如果你想要用其他人预训练的模型进行微调,需要做两件事:

a) 创建网络 
你可以写python代码来手动创建和原来一样的模型。但是,想想看,我们已经将原始网络保存在了.meta文件中,可以用tf.train.import()函数来重建网络:

saver = tf.train.import_meta_graph("my_test_model-1000.meta")
  • 1

记住,import_meta_graph函数将只将定义在.meta文件中的网络添加到当前的图上。因此,它虽然帮你创建了额图/网络,但我们还是需要导入我们在这个图上训练好的模型的参数。

b) 导入参数 
我们可以调用由tf.train.Saver()创建的对象saver中的restore方法来恢复网络中的参数。

with tf.Session() as sess:
  new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')
  new_saver.restore(sess, tf.train.latest_checkpoint('./'))
  • 1
  • 2
  • 3

这样,张量的值(如w1和w2)就被恢复并且可以访问了:

with tf.Session() as sess:    
    saver = tf.train.import_meta_graph('my-model-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))
    print(sess.run('w1:0'))
##Model has been restored. Above statement will print the saved value of w1.
  • 1
  • 2
  • 3
  • 4
  • 5

现在你已经理解了如何保存和导入Tensorflow模型。在下一节,我会介绍一个实际应用即导入任何预训练好的模型。

你可能感兴趣的:(Tensorflow模型保存和恢复 meta,ckpt含义)