tensorflow 模型预训练后的参数restore finetuning

之前训练的网络中有一部分可以用到一个新的网络中,但是不知道存储的参数如何部分恢复到新的网络中,也了解到有许多网络是通过利用一些现有的网络结构,通过finetuning进行改造实现的,因此了解了一下关于模型预训练后部分参数restore和finetuning的内容

更多内容参见:

https://blog.csdn.net/mieleizhi0522/article/details/80535189

https://blog.csdn.net/leo_xu06/article/details/79200634

https://blog.csdn.net/b876144622/article/details/79962727

https://blog.csdn.net/ying86615791/article/details/76215363

首先了解一下变量(tf.Variable),变量是tf框架中用于存储参数的对象,我们这里要恢复的参数也是variable类型的。训练的参数是放在不同名字下的variable中的,checkpoint中存储的变量也是通过不同的名字进行区分的,这里如果要恢复指定的参数可以使用

with tf.variable_scope('', reuse = True):
        sess.run(tf.get_variable(your_var_name).assign(reader.get_tensor(pretrained_var_name)))

Saver是用于保存变量的对象。下面是saver对象的创建和调用

saver = tf.train.Saver()
save_path = saver.save(sess, "/tmp/model.ckpt")

 如果仅在session开始时恢复模型变量的一个子集,需要对剩下的变量执行初始化op。

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore only 'v2' using the name "my_v2"
saver = tf.train.Saver({"my_v2": v2})

对已有checkpoint内容进行查看,可以使用一下代码(来自https://blog.csdn.net/mieleizhi0522/article/details/80535189),然后就可以结合之前的指定变量名的方法对参数进行restore了。注意,在完成部分参数的restore后要记得对没有初始化的变量进行初始化,否则报错。


    import tensorflow as tf

    import os

    from tensorflow.python import pywrap_tensorflow

    model_dir=r'G:\KeTi\C3D'

    checkpoint_path = os.path.join(model_dir, "sports1m_finetuning_ucf101.model")

    # 从checkpoint中读出数据

    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)

    # reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法

    var_to_shape_map = reader.get_variable_to_shape_map()

    # 输出权重tensor名字和值

    for key in var_to_shape_map:

    print("tensor_name: ", key,reader.get_tensor(key).shape)

输出

tensor_name: var_name/wc4a (3, 3, 3, 256, 512)

tensor_name: var_name/wc3a (3, 3, 3, 128, 256)

tensor_name: var_name/wd1 (8192, 4096)

tensor_name: var_name/wc5b (3, 3, 3, 512, 512)

tensor_name: var_name/bd1 (4096,)

tensor_name: var_name/wd2 (4096, 4096)

tensor_name: var_name/wout (4096, 101)

tensor_name: var_name/wc1 (3, 3, 3, 3, 64)

tensor_name: var_name/bc4b (512,)

tensor_name: var_name/wc2 (3, 3, 3, 64, 128)

tensor_name: var_name/bc3a (256,)

tensor_name: var_name/bd2 (4096,)

tensor_name: var_name/bc5a (512,)

tensor_name: var_name/bc2 (128,)

tensor_name: var_name/bc5b (512,)

tensor_name: var_name/bout (101,)

tensor_name: var_name/bc4a (512,)

tensor_name: var_name/bc3b (256,)

tensor_name: var_name/wc4b (3, 3, 3, 512, 512)

tensor_name: var_name/bc1 (64,)

tensor_name: var_name/wc3b (3, 3, 3, 256, 256)

tensor_name: var_name/wc5a (3, 3, 3, 512, 512)
 

你可能感兴趣的:(tensorflow)