之前训练的网络中有一部分可以用到一个新的网络中,但是不知道存储的参数如何部分恢复到新的网络中,也了解到有许多网络是通过利用一些现有的网络结构,通过finetuning进行改造实现的,因此了解了一下关于模型预训练后部分参数restore和finetuning的内容
更多内容参见:
https://blog.csdn.net/mieleizhi0522/article/details/80535189
https://blog.csdn.net/leo_xu06/article/details/79200634
https://blog.csdn.net/b876144622/article/details/79962727
https://blog.csdn.net/ying86615791/article/details/76215363
首先了解一下变量(tf.Variable),变量是tf框架中用于存储参数的对象,我们这里要恢复的参数也是variable类型的。训练的参数是放在不同名字下的variable中的,checkpoint中存储的变量也是通过不同的名字进行区分的,这里如果要恢复指定的参数可以使用
with tf.variable_scope('', reuse = True):
sess.run(tf.get_variable(your_var_name).assign(reader.get_tensor(pretrained_var_name)))
Saver是用于保存变量的对象。下面是saver对象的创建和调用
saver = tf.train.Saver()
save_path = saver.save(sess, "/tmp/model.ckpt")
如果仅在session开始时恢复模型变量的一个子集,需要对剩下的变量执行初始化op。
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore only 'v2' using the name "my_v2"
saver = tf.train.Saver({"my_v2": v2})
对已有checkpoint内容进行查看,可以使用一下代码(来自https://blog.csdn.net/mieleizhi0522/article/details/80535189),然后就可以结合之前的指定变量名的方法对参数进行restore了。注意,在完成部分参数的restore后要记得对没有初始化的变量进行初始化,否则报错。
import tensorflow as tf
import os
from tensorflow.python import pywrap_tensorflow
model_dir=r'G:\KeTi\C3D'
checkpoint_path = os.path.join(model_dir, "sports1m_finetuning_ucf101.model")
# 从checkpoint中读出数据
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法
var_to_shape_map = reader.get_variable_to_shape_map()
# 输出权重tensor名字和值
for key in var_to_shape_map:
print("tensor_name: ", key,reader.get_tensor(key).shape)
输出
tensor_name: var_name/wc4a (3, 3, 3, 256, 512)
tensor_name: var_name/wc3a (3, 3, 3, 128, 256)
tensor_name: var_name/wd1 (8192, 4096)
tensor_name: var_name/wc5b (3, 3, 3, 512, 512)
tensor_name: var_name/bd1 (4096,)
tensor_name: var_name/wd2 (4096, 4096)
tensor_name: var_name/wout (4096, 101)
tensor_name: var_name/wc1 (3, 3, 3, 3, 64)
tensor_name: var_name/bc4b (512,)
tensor_name: var_name/wc2 (3, 3, 3, 64, 128)
tensor_name: var_name/bc3a (256,)
tensor_name: var_name/bd2 (4096,)
tensor_name: var_name/bc5a (512,)
tensor_name: var_name/bc2 (128,)
tensor_name: var_name/bc5b (512,)
tensor_name: var_name/bout (101,)
tensor_name: var_name/bc4a (512,)
tensor_name: var_name/bc3b (256,)
tensor_name: var_name/wc4b (3, 3, 3, 512, 512)
tensor_name: var_name/bc1 (64,)
tensor_name: var_name/wc3b (3, 3, 3, 256, 256)
tensor_name: var_name/wc5a (3, 3, 3, 512, 512)