Tensorflow 模型加载及部分变量初始化

最近在做预训练部分图模型,将这部分图模型重新加载到一个新的图中,并加入一些新的op。下面是一些遇到的问题,调试方法以及解决方案。

1、从已有图中restore参数

saver_restore = tf.train.import_meta_graph(meta_path_restore)
saver_restore.restore(sess,)

2、通过tensor的名字获取变量

input_y = saver_restore.get_tensor_by_name('name:0')

P.S.在实验过程中,我自己尝试了一种方法,在外部创建会话,直接将需要加载的参数通过会话加载进来也是可以的。

sess = tf.Session()
model1= Model1(Session=sess, restore='Model1参数path')
model2= Model2(Session=sess, restore='Model2参数path')

这样也是可行的,但是如果加入其它操作的话会出现attempting to use uninitialized value。出现这个问题的原因是在部分加载变量后,添加了其它操作,同时又没有做对未赋值变量做initialize。解决办法

1、先添加需要的其它操作,然后运行

tf.global_variables_initializer()

随后对部分参数加载。这里需要注意,如果是最后执行global_variables_initializer()的话,之前所有的赋值操作都会被覆盖掉,也即之前做的所有操作都是无意义的。

2、加载参数后,加入新的操作,最后对没有初始化的部分参数进行初始化操作。

uninit_vars=[]
for var in tf.all_variables():
    try:
       sess.run(var)
    except tf.errors.FailedPreconditionError:
       uninit_vars.append(var)
init_new_vars_op = tf.initialize_variables(uninit_vars)
sess.run(init_new_vars_op)

这里使用的是部分参数初始化,通过这种方法就可以避免需要加载参数后再加入其他操作无法初始化参数的问题。

附加一个可以用于查看参数变量的代码,方便调试使用:

var = tf.trainable_variables()
value = sess.run(var)
for v in value:
    print(v)

另https://blog.csdn.net/ying86615791/article/details/76215363这篇写的也很好,可供参考。

你可能感兴趣的:(Tensorflow 模型加载及部分变量初始化)