tensorflow实现变量共享

如果想要达到重复利用变量的效果, 我们就要使用 tf.variable_scope(), 并搭配 tf.get_variable() 这种方式产生和提取变量. 不像 tf.Variable() 每次都会产生新的变量, tf.get_variable() 如果遇到了同样名字的变量时, 它会单纯的提取这个同样名字的变量(避免产生新变量). 而在重复使用的时候, 一定要在代码中强调 scope.reuse_variables(), 否则系统将会报错, 以为你只是单纯的不小心重复使用到了一个变量.

以下函数获取scope_name命名空间下变量名为var_name的变量,不存在创建,存在则返回已存在的变量

import tensorflow as tf
sess = tf.Session()
def get_scope_variable(scope_name, var_name, shape=None):
  with tf.variable_scope(scope_name) as scope:
    try:
      var = tf.get_variable(var_name, shape)
    except ValueError:
      scope.reuse_variables()
      var = tf.get_variable(var_name)
  return var

var_1 = get_scope_variable("cur_scope", "my_var", [100])
var_2 = get_scope_variable("cur_scope", "my_var", [100])
print(var_1 is var_2)
print(var_1.name)  # 此时变量名为  cur_scope/my_var
print(var_2.name)

reuse设置为true不存在会异常,设置为False,存在重名会异常。故我们捕获异常来判断是否存在。 reuse=tf.AUTO_REUSE会自动识别,比较好用

            if finetune:
                weight = tf.constant(data_dict[name][0], name="weights")
                bias = tf.constant(data_dict[name][1], name="bias")
                # print("finetune")
            else:
                weight = tf.get_variable('weights',initializer=tf.truncated_normal([in_channel, out_channel], stddev=0.1))
                bias = tf.get_variable('bias',initializer=tf.constant(0.1, dtype=tf.float32, shape=[out_channel]), trainable=True)

你可能感兴趣的:(tensorflow,tensorflow,python)