以下函数获取scope_name命名空间下变量名为var_name的变量,不存在创建,存在则返回已存在的变量
import tensorflow as tf
sess = tf.Session()
def get_scope_variable(scope_name, var_name, shape=None):
with tf.variable_scope(scope_name) as scope:
try:
var = tf.get_variable(var_name, shape)
except ValueError:
scope.reuse_variables()
var = tf.get_variable(var_name)
return var
var_1 = get_scope_variable("cur_scope", "my_var", [100])
var_2 = get_scope_variable("cur_scope", "my_var", [100])
print(var_1 is var_2)
print(var_1.name) # 此时变量名为 cur_scope/my_var
print(var_2.name)
reuse设置为true不存在会异常,设置为False,存在重名会异常。故我们捕获异常来判断是否存在。 reuse=tf.AUTO_REUSE会自动识别,比较好用
if finetune:
weight = tf.constant(data_dict[name][0], name="weights")
bias = tf.constant(data_dict[name][1], name="bias")
# print("finetune")
else:
weight = tf.get_variable('weights',initializer=tf.truncated_normal([in_channel, out_channel], stddev=0.1))
bias = tf.get_variable('bias',initializer=tf.constant(0.1, dtype=tf.float32, shape=[out_channel]), trainable=True)