tensorflow2.x变量初始化

TensorFlow2.x和PyTorch已经很像了,可以直接输出Tensor的值,不用sess.run()了,很直观。

tf2.x其中一个特点是去掉了很多冗余API,tf1.x中,定义变量有两个函数,tf.Variable()和tf.get_variable(),在tf2.x的API中,只剩tf.Variable()了,另一个变成tf.compat.v1.get_variable()了。

tf.Variable(
    initial_value=None, trainable=None, validate_shape=True, caching_device=None,
    name=None, variable_def=None, dtype=None, import_scope=None, constraint=None,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.compat.v1.VariableAggregation.NONE, shape=None
)

变量在深度神经网络中一般是可学习的参数,因为每一轮迭代都需要根据反向传播修改这些可学习参数,所以要定义成变量。变量在使用前都会进行赋值,叫做初始化。

tf2.x中,tf.Variable()初始化张量的API有两类,一类是一个函数,比如

tf.random.normal(
    shape, mean=0.0, stddev=1.0, dtype=tf.dtypes.float32, seed=None, name=None
)

同时给出张量的shape和值。调用代码:

kernel = tf.Variable(initial_value= tf.random.normal([9,9,3,1]))

还有一类是一个类,比如

tf.random_normal_initializer(
    mean=0.0, stddev=0.05, seed=None
)

这些类的构造函数中不需要指定张量的shape,如果要在tf.Variable()使用这种类,需要调用这种类的__call__方法,

def __call__(self, shape, dtype=dtypes.float32):调用代码:
kernel = tf.Variable(initial_value=tf.random_normal_initializer()(shape=[9, 9, 3, 1], dtype=tf.float32))

 

你可能感兴趣的:(tensorflow)