TensorFlow的变量共享

tf.get_variable()方法是TensorFlow提供的比tf.Variable()稍微高级的创建/获取变量的方法,它的工作方式根据当前的变量域(Variable Scope)的reuse属性变化而变化,我们可以通过tf.get_variable_scope().reuse来查看这个属性,它默认是False

  1. tf.get_variable_scope().reuse == False
    此时调用tf.get_variable(name, shape, dtype, initializer),我们可以创建一个新的变量(或者说张量),这个变量的名字为name,维度是shape,数据类型是dtype,初始化方法是指定的initializer。如果名字为name的变量已经存在的话,会导致ValueError
    一个例子如下:
# create var
entity = tf.get_variable(name='entity', initializer=...)
  1. tf.get_variable_scope().reuse == True
    此时调用tf.get_variable(name),我们 可以 得到一个已经存在的名字为name的变量,如果这个变量不存在的话,会导致ValueError
    一个例子如下:
# reuse var
tf.get_variable_scope().reuse_variables()  # set reuse to True
entity = tf.get_variable(name='entity')

上面的两种情况得到的变量的名字都为name,这是假设在默认的变量域中调用tf.get_variable(),如果在指定的变量域中调用,比如:

# create var
with tf.variable_scope('embedding'):
    entity = tf.get_variable(name='entity', initializer=...)
# reuse var
with tf.variable_scope('embedding', reuse=True):
    entity = tf.get_variable(name='entity')

那么得到的变量entity的名字则是embedding/entity

你可能感兴趣的:(TensorFlow的变量共享)