tf.Variable()参数
tf.Variable(initial_value=None,
trainable=None,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None,
constraint=None,
use_resource=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.VariableAggregation.NONE,
shape=None
)
经常使用的参数有initial_value
、name
、shape
三个,分别是初始化,命名和规定所需要的形状大小。举个例子:
import tensorflow as tf
v1=tf.Variable(tf.random_normal(shape=[4,3],mean=0,stddev=1),name='v1')
v2=tf.Variable(tf.constant(2),name='v2')
v3=tf.Variable(tf.ones([4,3]),name='v3')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print 'v1:\n',sess.run(v1)
print 'v2:\n',sess.run(v2)
print 'v3:\n',sess.run(v3)
运行结果
v1:
[[ 0.4027793 0.72299665 -1.4619899 ]
[-1.7155927 -0.8806208 -0.39554796]
[-0.4185343 -1.562368 1.9035501 ]
[-0.7704326 -1.9970375 2.224315 ]]
v2:
2
v3:
[[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]]
tf.get_variable()参数
tf.get_variable(name,
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=None,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None,
constraint=None,
synchronization=tf.VariableSynchronization.AUTO,
aggregation=tf.VariableAggregation.NONE
)
与tf.Variable()
一样,经常使用的参数有initial_value
、name
、shape
三个,分别是初始化,命名和规定所需要的形状大小。举个例子:
import tensorflow as tf
v1 = tf.get_variable(name='v1', shape=[2,3], initializer=tf.random_normal_initializer(mean=0, stddev=1))
v2 = tf.get_variable(name='v2', shape=[1], initializer=tf.constant_initializer(1))
v3 = tf.get_variable(name='v3', shape=[2,3], initializer=tf.ones_initializer())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(v1)
print sess.run(v2)
print sess.run(v2)
运行结果如下:
v1:
[[-0.06989016 0.44355923 -1.2255034 ]
[ 0.46685636 -0.8572208 -0.16504966]]
v2:
[1.]
v3:
[[1. 1. 1.]
[1. 1. 1.]]
tf.Variable()、tf.get_variable() 两者区别
tf.get_variable创建变量时,会进行变量检查,当设置为共享变量时(通过with tf.variable_scope(name_or_scope='', reuse=tf.AUTO_REUSE)设置),检查到第二个拥有相同名字的变量,就返回已创建的相同的变量;如果没有设置共享变量,则会报ValueError: Variable varx alreadly exists, disallowed
的错误。而tf.Variable()创建变量时,name属性值允许重复,检查到相同名字的变量时,由自动别名机制创建不同的变量。举个例子:
with tf.variable_scope(name_or_scope='', reuse=tf.AUTO_REUSE):
var1 = tf.get_variable(name='var1', shape=[1], initializer=None, dtype=tf.float32)
var11 = tf.get_variable(name='var1')
var2 = tf.Variable(name='var2', initial_value=[1], dtype=tf.float32)
var21 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
with tf.Session() as sess:
print var1.name
print var11.name
print var2.name
print var21.name
输出name时,如下:
var1:0
var1:0
name_scope_2/var2:0
name_scope_2/var2_1:0