The difference between name_scope and variable_scope

First, there are two ways to create a variable (TensorFlow 1.x; only the most commonly used arguments are listed):

  • tf.Variable(initial_value, name, dtype, trainable, collections)
  • tf.get_variable(name, shape, dtype, initializer, trainable, collections)

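tf.Variable creates a new variable on every call. If the requested name is already in use in the graph, TensorFlow makes it unique by appending a suffix _N:

import tensorflow as tf  # TensorFlow 1.x API

first_a = tf.Variable(name='a', initial_value=1, dtype=tf.int32)
second_a = tf.Variable(name='a', initial_value=1, dtype=tf.int32)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(first_a.name)   # a:0
print(second_a.name)  # a_1:0 (in a fresh graph)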
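
The suffixing rule described above can be mimicked in plain Python. This is only a toy sketch of the observable behaviour, not TensorFlow's source; the class NameRegistry and its method unique_name are hypothetical names invented for the illustration.

```python
class NameRegistry:
    """Toy model of per-graph name uniquification (hypothetical helper)."""

    def __init__(self):
        # Maps every name handed out so far to how often it was requested.
        self._used = {}

    def unique_name(self, name):
        count = self._used.get(name, 0)
        self._used[name] = count + 1
        if count == 0:
            # First request: the name is free, use it unchanged.
            return name
        # Later requests: try name_1, name_2, ... until one is free.
        candidate = "%s_%d" % (name, count)
        while candidate in self._used:
            count += 1
            candidate = "%s_%d" % (name, count)
        self._used[candidate] = 1
        return candidate


registry = NameRegistry()
names = [registry.unique_name("a") for _ in range(3)]
print(names)  # ['a', 'a_1', 'a_2']
```

The first request keeps the bare name, and only later requests get a numeric suffix, which matches the a:0 / a_1:0 pattern seen when the snippet above runs in a fresh graph.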

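tf.get_variable, by contrast, looks the variable up by name: if a variable with that name already exists, it can return that variable instead of creating a new one.
Note, however, that this only happens inside a scope with reuse enabled. The following is an error:

first_a = tf.get_variable(name='a', shape=[1], initializer=tf.zeros_initializer, dtype=tf.int32)
# The second call raises ValueError: Variable a already exists, disallowed.
second_a = tf.get_variable(name='a', shape=[1], initializer=tf.zeros_initializer, dtype=tf.int32)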

Without reuse, tf.get_variable cannot fetch a variable whose name already exists; and reuse can only be set on a tf.variable_scope:

with tf.variable_scope('var_scope') as scope:
    v = tf.get_variable(name='v', shape=[1], initializer=tf.zeros_initializer)
with tf.variable_scope(scope, reuse=True):
    v1 = tf.get_variable(name='v', shape=[1], initializer=tf.zeros_initializer)
sess = tf.Session()
sess.run(tf.global_variables_initializer())

assert v is v1  # both names refer to the same variable object
print(v.name)   # var_scope/v:0
print(v1.name)  # var_scope/v:0
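
The create-or-reuse rule can be summarised with a small plain-Python model. This is a simplified sketch under stated assumptions, not TensorFlow's implementation: ToyVariableStore is a hypothetical class, variables are represented as plain dicts, and the error messages are shortened.

```python
class ToyVariableStore:
    """Toy model of the lookup rule behind tf.get_variable (hypothetical)."""

    def __init__(self):
        self._vars = {}  # full variable name -> variable object

    def get_variable(self, scope, name, reuse=False):
        full_name = scope + "/" + name if scope else name
        if full_name in self._vars:
            if not reuse:
                # Same name, reuse not requested: refuse to share silently.
                raise ValueError("Variable %s already exists, disallowed." % full_name)
            return self._vars[full_name]
        if reuse:
            # reuse=True never creates; the variable must already exist.
            raise ValueError("Variable %s does not exist." % full_name)
        var = {"name": full_name + ":0"}
        self._vars[full_name] = var
        return var


store = ToyVariableStore()
v = store.get_variable("var_scope", "v")               # creates
v1 = store.get_variable("var_scope", "v", reuse=True)  # fetches the same object
print(v is v1, v["name"])  # True var_scope/v:0
```

The model also shows the less obvious half of the rule: with reuse=True a missing variable is an error rather than a silent creation.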

Inside tf.name_scope

tf.get_variable ignores the name scope; only tf.Variable (and ordinary ops) get the prefix:

with tf.name_scope("my_scope"):
    v1 = tf.get_variable("var1", [1], dtype=tf.float32)
    v2 = tf.Variable(1, name="var2", dtype=tf.float32)
    a = tf.add(v1, v2)

print(v1.name)  # var1:0
print(v2.name)  # my_scope/var2:0
print(a.name)   # my_scope/Add:0

Inside tf.variable_scope

Both tf.get_variable and tf.Variable (as well as ops) get the prefix:

with tf.variable_scope("my_scope"):
    v1 = tf.get_variable("var1", [1], dtype=tf.float32)
    v2 = tf.Variable(1, name="var2", dtype=tf.float32)
    a = tf.add(v1, v2)

print(v1.name)  # my_scope/var1:0
print(v2.name)  # my_scope/var2:0
print(a.name)   # my_scope/Add:0

Because tf.get_variable ignores name_scope, variables can be shared across different name_scopes through tf.get_variable, but reuse must still be declared explicitly:

with tf.name_scope("foo"):
    with tf.variable_scope("var_scope"):
        v = tf.get_variable("var", [1])
with tf.name_scope("bar"):
    with tf.variable_scope("var_scope", reuse=True):
        v1 = tf.get_variable("var", [1])
assert v1 is v  # the same variable object is shared by both name scopes
print(v.name)   # var_scope/var:0
print(v1.name)  # var_scope/var:0
