tf.Variable()和tf.get_variable()

tf.Variable()参数

tf.Variable(initial_value=None,
            trainable=None,
            collections=None,
            validate_shape=True,
            caching_device=None,
            name=None,
            variable_def=None,
            dtype=None,
            expected_shape=None,
            import_scope=None,
            constraint=None,
            use_resource=None,
            synchronization=tf.VariableSynchronization.AUTO,
            aggregation=tf.VariableAggregation.NONE,
            shape=None
          )

经常使用的参数有initial_valuenameshape三个,分别是初始化,命名和规定所需要的形状大小。举个例子:

import tensorflow as tf
v1=tf.Variable(tf.random_normal(shape=[4,3],mean=0,stddev=1),name='v1')
v2=tf.Variable(tf.constant(2),name='v2')
v3=tf.Variable(tf.ones([4,3]),name='v3')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print 'v1:\n',sess.run(v1)
    print 'v2:\n',sess.run(v2)
    print 'v3:\n',sess.run(v3)

运行结果

v1:
[[ 0.4027793   0.72299665 -1.4619899 ]
 [-1.7155927  -0.8806208  -0.39554796]
 [-0.4185343  -1.562368    1.9035501 ]
 [-0.7704326  -1.9970375   2.224315  ]]

v2:
2

v3:
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]

tf.get_variable()参数

tf.get_variable(name,
                shape=None,
                dtype=None,
                initializer=None,
                regularizer=None,
                trainable=None,
                collections=None,
                caching_device=None,
                partitioner=None,
                validate_shape=True,
                use_resource=None,
                custom_getter=None,
                constraint=None,
                synchronization=tf.VariableSynchronization.AUTO,
                aggregation=tf.VariableAggregation.NONE
              )

tf.Variable()一样,经常使用的参数有initial_valuenameshape三个,分别是初始化,命名和规定所需要的形状大小。举个例子:

import tensorflow as tf

v1 = tf.get_variable(name='v1', shape=[2,3], initializer=tf.random_normal_initializer(mean=0, stddev=1))
v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(1))
v3 = tf.get_variable(name='v3', shape=[2,3], initializer=tf.ones_initializer())
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(v1)
    print sess.run(v2)
    print sess.run(v2)

运行结果如下:

v1:
[[-0.06989016  0.44355923 -1.2255034 ]
 [ 0.46685636 -0.8572208  -0.16504966]]

v2:
[1.]

v3:
[[1. 1. 1.]
 [1. 1. 1.]]

tf.Variable()、tf.get_variable() 两者区别

tf.get_variable创建变量时,会进行变量检查,当设置为共享变量时(通过with tf.variable_scope(name_or_scope='', reuse=tf.AUTO_REUSE)设置),检查到第二个拥有相同名字的变量,就返回已创建的相同的变量;如果没有设置共享变量,则会报ValueError: Variable varx alreadly exists, disallowed的错误。而tf.Variable()创建变量时,name属性值允许重复,检查到相同名字的变量时,由自动别名机制创建不同的变量。举个例子:

with tf.variable_scope(name_or_scope='', reuse=tf.AUTO_REUSE):
    var1 = tf.get_variable(name='var1', shape=[1], initializer=None, dtype=tf.float32)
    var11 = tf.get_variable(name='var1')
    var2 = tf.Variable(name='var2', initial_value=[1], dtype=tf.float32)
    var21 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
    with tf.Session() as sess:
        print var1.name
        print var11.name
        print var2.name
        print var21.name

输出name时,如下:
var1:0
var1:0
name_scope_2/var2:0
name_scope_2/var2_1:0

你可能感兴趣的:(tf.Variable()和tf.get_variable())