Variables: 创建、初始化、保存和加载

引言

当你训练一个模型的时候,你使用变量去保存和更新参数。在Tensorflow中变量是内存缓冲区中保存的张量(tensor)。它们必须被显示的初始化,可以在训练完成之后保存到磁盘上。之后,你可以重新加载这些值用于测试和模型分析。
本篇文档引用了如下的Tensorflow类。以下的链接指向它们更加详细的API:

  • tf.Variable 类。
  • tf.train.Saver 类。

创建

当你创建一个变量时,你传递一个tensor数据作为它的初始值给Variable()构造器。Tensorflow提供了一堆操作从常量或者随机值中产生tensor数据用于初始化。
注意这些操作要求你指定tensor数据的形状。这个形状自动的成为变量的形状。变量的形状是固定的。不过,Tensorflow提供了一些高级机制用于改变变量的形状。

# 创建两个变量
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="weights")
biases = tf.Variable(tf.zeros([200]), name="biases")

调用tf.Variable()会在计算图上添加这些节点:

  • 一个变量节点,用于保存变量的值
  • 一个初始化操作节点,用于将变量设置为初始值。它实际上是一个tf.assign节点。
  • 初始值节点,例如例子中的zeros节点也会被加入到计算图中。

tf.Variable()返回值是Python类tf.Variable的一个实例。

设备配置

一个变量在创建时可以被塞进制定的设备,通过使用 with tf.device(…)::

# 将变量塞进CPU里
with tf.device("/cpu:0"):
  v = tf.Variable(...)

# 将变量塞进GPU里
with tf.device("/gpu:0"):
  v = tf.Variable(...)

# 将变量塞进指定的参数服务任务里
with tf.device("/job:ps/task:7"):
  v = tf.Variable(...)

注意一些改变变量的操作,例如v.assign()和在tf.train.Optimizer中变量的更新操作,必须与与变量创建时运行在同一设备上。创建这些操作是,不兼容的设备配置将会忽略。

初始化

变量的初始化必须找模型的其他操作之前,而且必须显示的运行。最简单的方式是添加一个节点用于初始化所有的变量,然后在使用模型之前运行这个节点。
或者你可以选择从checkpoint文件中加载变量,之后将会介绍。
使用tf.global_variables_initializer()添加节点用于初始化所有的变量。在你构建完整个模型并在会话中加载模型后,运行这个节点。

# 创建两个变量
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")
biases = tf.Variable(tf.zeros([200]), name="biases")
...
# 添加用于初始化变量的节点
init_op = tf.global_variables_initializer()

# 然后,在加载模型的时候
with tf.Session() as sess:
  # 运行初始化操作
  sess.run(init_op)
  ...
  # 使用模型
  ...

从别的变量中初始化

有时候你需要利用另一个变量来初始化当前变量。由于tf.global_variables_initializer()添加的节点适用于并行的初始化所有变量,所有如果你有这个需求,你得小心谨慎。
为了从另一个变量中初始化一个新的变量,使用变量的另一个方法initialized_value()。你可以直接将旧变量的初始值作为新变量的初始值,或者你可以将旧变量的初始值进行一些运算后再作为新变量的初始值。

# 使用随机数创建一个变量
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")
# 创建另一个变量,它与weights拥有相同的初始值
w2 = tf.Variable(weights.initialized_value(), name="w2")
# 创建另一个变量,它的初始值是weights的两倍
w_twice = tf.Variable(weights.initialized_value() * 2.0, name="w_twice")

自定义初始化

tf.global_variables_initializer()能够将所有的变量一步到位的初始化,非常的方便。你也可以将指定的列表传递给它,只初始化列表中的变量。 更多的选项请查看Variables Documentation,包括检查变量是否初始化。

保存和加载

最简单的保存和加载模型的方法是使用tf.train.Saver 对象。它的构造器将在计算图上添加save和restore节点,针对图上所有或者指定的变量。saver对象提供了运行这些节点的方法,只要指定用于读写的checkpoint的文件。

checkpoint文件

变量以二进制文件的形式保存在checkpoint文件中,粗略地来说就是变量名与tensor数值的一个映射
当你创建一个Saver对象是,你可以选择变量在checkpoint文件中名字。默认情况下,它会使用Variable.name作为变量名。
为了理解什么变量在checkpoint文件中,你可以使用inspect_checkpoint库,更加详细地,使用print_tensors_in_checkpoint_file函数。

保存变量

使用tf.train.Saver()创建一个Saver对象,然后用它来管理模型中的所有变量。

# 创建一些变量
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# 添加用于初始化变量的节点
init_op = tf.global_variables_initializer()

# 添加用于保存和加载所有变量的节点
saver = tf.train.Saver()

# 然后,加载模型,初始化所有变量,完成一些操作后,把变量保存到磁盘上
with tf.Session() as sess:
  sess.run(init_op)
  # 进行一些操作
  ..
  # 将变量保存到磁盘上
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

加载变量

Saver对象还可以用于加载变量。注意当你从文件中加载变量是,你不用实现初始化它们。

# 创建两个变量
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# 添加用于保存和加载所有变量的节点
saver = tf.train.Saver()

# 然后,加载模型,使用saver对象从磁盘上加载变量,之后再使用模型进行一些操作
with tf.Session() as sess:
  # 从磁盘上加载对象
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # 使用模型进行一些操作
  ...

选择变量进行保存和加载

如果你不传递任何参数给tf.train.Saver(),Saver对象将处理图中的所有变量。每一个变量使用创建时传递给它的名字保存在磁盘上。
有时候,我们需要显示地指定变量保存在checkpoint文件中的名字。例如,你可能使用名为“weights”的变量训练模型;在保存的时候,你希望用“params”为名字保存。
有时候,我们只保存和加载模型的部分参数。例如,你已经训练了一个5层的神经网络;现在你想训练一个新的神经网络,它有6层。加载旧模型的参数作为新神经网络前5层的参数。
通过传递给tf.train.Saver()一个Python字典,你可以简单地指定名字和想要保存的变量。字典的keys是保存在磁盘上的名字,values是变量的值。
注意:
如果你需要保存和加载不同子集的变量,你可以随心所欲地创建任意多的saver对象。同一个变量可以被多个saver对象保存。它的值仅仅在restore()方法运行之后发生改变。
如果在会话开始之初,你仅加载了部分变量,你还需要为其他变量运行初始化操作。参见tf.initialize_variables() 查询更多的信息。

# 创建一些对象
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# 添加一个节点用于保存和加载变量v2,使用名字“my_v2”
saver = tf.train.Saver({"my_v2": v2})
# Use the saver object normally after that.
...

你可能感兴趣的:(Tensorflow)