saver实例代码:
import tensorflow as tf
# Variables whose values will be written to the checkpoint. The names
# ('weights', 'biases') are the keys under which the values are stored.
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')

init = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    # Run the initializer so the variables hold their initial values
    # before they are saved.
    sess.run(init)
    # saver.save() returns the path prefix the checkpoint was written to.
    save_path = saver.save(sess, "/home/violet/aitest/ResNet/logs/test/save_net.ckpt")
    print("Save to path: ", save_path)
restore实例代码:
# restore variables
# redefine the same shape and same type for your variables
import tensorflow as tf
import numpy as np
# The variables must be declared with the same shape, dtype and name as the
# ones that were saved; the initial values given here are only placeholders.
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

# No init step is needed: restore() assigns the saved values to the variables.
saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, "/home/violet/aitest/ResNet/logs/test/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))
一次 saver.save() 后可以在文件夹中看到新增的四个文件:
checkpoint文件保存了一个目录下所有的模型文件列表,model.ckpt.meta保存了TensorFlow计算图的结构信息,model.ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同,但加载restore时的文件路径名是由checkpoint文件中的“model_checkpoint_path”值决定的。简单理解就是权重等参数以字典的形式被保存到 .ckpt.data 文件中;图和元数据被保存到 .ckpt.meta 文件中,可以被 tf.train.import_meta_graph 加载到当前默认的图。
根据已有模型进行微调
(1)利用tf.train.Saver()从checkpoint恢复模型
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...

# Add ops to restore all the variables.
restorer = tf.train.Saver()

# Alternatively, add ops to restore only some variables. Use one of the two
# forms; this assignment replaces the saver created above.
restorer = tf.train.Saver([v1, v2])

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
    # Restore variables from disk.
    restorer.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")
    # Do some work with the model
    ...
(2)部分恢复模型参数
# Create some variables. The name is passed positionally so that the `...`
# placeholder for the remaining arguments stays valid Python.
v1 = slim.variable("v1", ...)
v2 = slim.variable("nested/v2", ...)
...

# Get list of variables to restore (which contains only 'v2'). These are all
# equivalent methods; pick one of them:
variables_to_restore = slim.get_variables_by_name("v2")
# or
variables_to_restore = slim.get_variables_by_suffix("2")
# or
variables_to_restore = slim.get_variables(scope="nested")
# or
variables_to_restore = slim.get_variables_to_restore(include=["nested"])
# or
variables_to_restore = slim.get_variables_to_restore(exclude=["v1"])

# Create the saver which will be used to restore the variables.
restorer = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    # Restore variables from disk.
    restorer.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")
    # Do some work with the model
    ...
(3)当图的变量名与checkpoint中的变量名不同时,恢复模型参数
当从checkpoint文件中恢复变量时,Saver在checkpoint文件中定位到变量名,并且把它们映射到当前图中的变量。之前的例子中,我们创建了Saver,并为其提供了变量列表作为参数。这时,在checkpoint文件中用于定位的变量名,是隐含地从每个作为参数给出的变量的var.op.name获得的。这一方式在图与checkpoint文件中变量名字相同时,可以很好地工作。而当名字不同时,必须给Saver提供一个字典,将checkpoint文件中的变量名映射到图中的每个变量,例子见下:
# Assuming that 'conv1/weights' should be restored from 'vgg16/conv1/weights'
def name_in_checkpoint(var):
    """Return the checkpoint name for a graph variable.

    The checkpoint stores every variable under an extra 'vgg16/' prefix, so
    the checkpoint name is that prefix followed by the variable's name in
    the current graph (``var.op.name``).
    """
    return 'vgg16/' + var.op.name
# Assuming that 'conv1/weights' and 'conv1/bias' should be restored from 'conv1/params1' and 'conv1/params2'
def name_in_checkpoint(var):
    """Return the checkpoint name for a graph variable.

    'weights' in the graph name is replaced by 'params1'; otherwise 'bias'
    is replaced by 'params2'. The 'weights' check runs first, so a name
    containing both substrings only has 'weights' rewritten.

    Variables matching neither pattern keep their graph name, so the
    function never returns ``None`` as a key for the Saver mapping.
    """
    graph_name = var.op.name
    if "weights" in graph_name:
        return graph_name.replace("weights", "params1")
    if "bias" in graph_name:
        return graph_name.replace("bias", "params2")
    # Fallback: same name in the graph and in the checkpoint.
    return graph_name
# Build a {name_in_checkpoint: graph_variable} dict. Passing a dict instead of
# a list tells the Saver which checkpoint entry to load into each variable.
model_variables = slim.get_model_variables()
variables_to_restore = {name_in_checkpoint(var): var for var in model_variables}
restorer = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    # Restore variables from disk.
    restorer.restore(sess, "/tmp/model.ckpt")
(4)在一个不同的任务上对网络进行微调
比如我们要将在1000类的ImageNet分类任务上预训练好的模型,应用于20类的Pascal VOC分类任务。由于最后几层的输出类别数不同,我们只导入部分层的参数,见下例:
# Load the data and group it into batches of 32.
image, label = MyPascalVocDataLoader(...)
images, labels = tf.train.batch([image, label], batch_size=32)

# Create the model
predictions = vgg.vgg_16(images)

train_op = slim.learning.create_train_op(...)

# Specify where the Model, trained on ImageNet, was saved.
model_path = '/path/to/pre_trained_on_imagenet.checkpoint'

# Specify where the new model will live:
log_dir = '/path/to/my_pascal_model_dir/'

# Restore only the convolutional layers: the fully connected layers
# 'fc6', 'fc7' and 'fc8' are excluded and are not loaded from the checkpoint.
variables_to_restore = slim.get_variables_to_restore(exclude=['fc6', 'fc7', 'fc8'])
# assign_from_checkpoint_fn is part of the slim namespace, like the other
# helpers used here; the unqualified name would raise a NameError.
init_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)

# Start training.
slim.learning.train(train_op, log_dir, init_fn=init_fn)
原文链接:
https://www.cnblogs.com/bmsl/p/dongbin_bmsl_01.html