TensorFlow中tf.train.Saver类说明

 

Saver类

类的详细信息在github可找到

功能:保存和恢复变量

有关变量,保存和恢复的概述,请参见变量

将Saver类添加ops 从而在checkpointes里save和restore变量 。它还提供了运行这些操作的便捷方法。

Checkpoints是专有格式的二进制文件,它将变量名称映射到张量值。测试Checkpoints内容的最佳方式是使用Saver来加载它

Savers可以使用提供的计数器自动为Checkpoint文件名编号,这使您可以在训练模型时在不同的步骤中保留多个Checkpoints。例如,您可以使用训练步骤编号对Checkpoint文件名进行编号。为避免填满磁盘,储存器会自动管理Checkpoint文件。例如,他们只能保留N个最新文件,或每N小时训练一个Checkpoint。

 

可以通过将值传递给可选global_step参数来对检查点文件名进行编号 save:

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

此外,Saver()构造函数的可选参数允许您控制磁盘上Checkpoint文件的扩散:

 

  • max_to_keep表示要保留的最近文件的最大数量。创建新文件时,将删除旧文件。如果为None或0,则不会从文件系统中删除任何Checkpoint,但只有最后一个Checkpoint保留在checkpoint文件中。默认为5(即保留最近的5个Checkpoint文件。)
  • keep_checkpoint_every_n_hours:除了保留最新的 max_to_keep检查点文件之外,您可能还希望每N小时的训练保留一个Checkpoint文件。如果您想稍后分析模型在长时间训练期间的进展情况,这将非常有用。例如,传递keep_checkpoint_every_n_hours=2确保每2小时训练保留一个检查点文件。默认值10,000小时可有效禁用该功能。

 

请注意,您仍然需要调用save()方法来保存模型。将这些参数传递给构造函数不会自动为您保存变量。

定期保存的训练代码如下:

.

...
# Create a saver.
saver = tf.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
for step in xrange(1000000):
    sess.run(..training_op..)
    if step % 1000 == 0:
        # Append the step number to the checkpoint name:
        saver.save(sess, 'my-model', global_step=step)

除了检查点文件之外,保存程序还在磁盘上保留协议缓冲区以及最近的Checkpoint列表。这用于管理编号的检查点文件,通过latest_checkpoint()它可以轻松发现最近检查点的路径。该协议缓冲区存储在Checkpoint文件同一目录下且名为“checkpoint”的文件中。

 

如果创建多个savers,则可以在调用中为协议缓冲区文件指定不同的文件名。 save():

 

__init__()

__init__(
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)

创建一个Saver。

构造函数添加操作以保存和恢复变量。

var_list指定将保存和恢复的变量。它可以作为一个dict或一个列表传递:

  • 一个dict名字变量:键是将被用来在检查点文件保存或恢复变量名。
  • 变量列表:变量将在检查点文件中使用其op名称进行键控。

例如:

v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# Or pass them as a list.
saver = tf.train.Saver([v1, v2])
# Passing a list is equivalent to passing a dict with the variable op names
# as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})

可选reshape参数if True,允许从保存文件中恢复变量,其中变量具有不同的shape,但元素和类型的数量相同。如果您已重新变换变量并希望从较旧的检查点重新加载它,则此选项非常有用.

可选sharded参数if True指示保护程序为每个设备分片检查点。

 

__init__参数说明:

  • var_list:Variable/ 的列表SaveableObject,或者将名称映射到SaveableObjects 的字典。如果None,默认为所有可保存对象的列表。
  • reshape:If True,允许从变量具有不同形状的检查点恢复参数。
  • sharded:如果True,将每个设备分成一个检查点。
  • max_to_keep:要保留的最近检查点的最大数量。默认为5。
  • keep_checkpoint_every_n_hours:保持检查站的频率。默认为10,000小时。
  • name:字符串。添加操作时用作前缀的可选名称。
  • restore_sequentially:A Bool,如果为true,则导致不同变量的恢复在每个设备中顺序发生。这可以在恢复非常大的模型时降低内存使用量。
  • saver_def:SaverDef使用可选的proto而不是运行构建器。这仅适用于想要Saver为先前构建的Graph具有a 的对象重新创建对象的专业代码Saver。该saver_def原型应该是返回一个 as_saver_def()的电话Saver说是为创建Graph。
  • builder:SaverBuilder如果saver_def未提供,则可以选择使用。默认为BulkSaverBuilder()。
  • defer_build:如果True,请将保存和恢复操作添加到 build()呼叫中。在这种情况下,build()应在最终确定图表或使用保护程序之前调用。
  • allow_empty:如果False(默认)如果图中没有变量则引发错误。否则,无论如何构建保护程序并使其成为无操作。
  • write_version:控制保存检查点时使用的格式。它还会影响某些文件路径匹配逻辑。V2格式是推荐的选择:它在恢复期间所需的内存和延迟方面比V1更加优化。无论此标志如何,Saver都能够从V2和V1检查点恢复。
  • pad_step_number:如果为True,则将检查点文件路径中的全局步骤编号填充到某个固定宽度(默认为8)。默认情况下这是关闭的。
  • save_relative_paths:如果True,将写入检查点状态文件的相对路径。如果用户想要复制检查点目录并从复制的目录重新加载,则需要这样做。
  • filename:如果在图形构造时知道,则用于变量加载/保存的文件名。

 

Raises:

  • TypeError:如果var_list无效。
  • ValueError:如果任何键或值var_list不是唯一的。
  • RuntimeError:如果启用了急切执行,var_list并且未指定要保存的变量列表。

 

Eager Compatibility

启用eager执行时,var_list必须指定 要保存的变量list或dict变量。否则,将出现RuntimeError。

 

属性

last_checkpoints

尚未删除的检查点文件名列表。

您可以将任何返回的值传递给restore()

 

返回值:检查点文件名列表,从最旧到最新排序。

 

方法

 

as_saver_def()

生成saver的表示。

返回:SaverDef的原因

 

build()

 

export_meta_graph(
    filename=None,
    collection_list=None,
    as_text=False,
    export_scope=None,
    clear_devices=False,
    clear_extraneous_savers=False,
    strip_default_attrs=False
)

写入MetaGraphDefsave_path / filename。

参数说明:

  • filename:可选的meta_graph文件名,包括路径。
  • collection_list:要收集的字符串键列表。
  • as_text:如果True,将meta_graph写为ASCII原型。
  • export_scope:可选string。要删除的名称范围。
  • clear_devices:是否要清除设备领域的一个Operation 或Tensor导出过程中。
  • clear_extraneous_savers:从图表中删除与此Saver无关的任何与Saver相关的信息(Save / Restore ops和SaverDefs)。
  • strip_default_attrs:布尔值。如果True,将从NodeDefs中删除默认值属性。有关详细指南,请参阅Stripping Default-Valued Attributes 。

 

 

 

@staticmethod
from_proto(
    saver_def,
    import_scope=None
)

返回Saver从中创建的对象saver_def

 

参数说明:

  • saver_def:SaverDef协议缓冲区。
  • import_scope:可选string。命名范围使用。

返回值:

从saver_def构造的Saver。

 

recover_last_checkpoints(checkpoint_paths)

崩溃后恢复内部保护状态。

此方法对于恢复“self._last_checkpoints”状态很有用。

全局的Checkpoints-checkpoint_paths,如果文件存在,请使用其mtime作为Checkpoint时间戳。

参数说明:

  • checkpoint_paths:检查点路径列表。

 

 

restore(
    sess,
    save_path
)

恢复以前保存的变量。

此方法运行构造函数添加的ops以恢复变量。它需要启动图表的会话。要恢复的变量不必初始化,因为恢复本身就是一种初始化变量的方法。

该save_path参数通常是先前从save()调用或调用返回的值 latest_checkpoint()。

 

参数说明:

  • sess:a Session用于恢复参数。None 是默认模式。
  • save_path:先前保存参数的路径。

 

可能出现的异常:

  • ValueError:如果save_path为None或者不是有效的CheckPoint。

 

save(
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True,
    strip_default_attrs=False
)

保存变量。

此方法运行构造函数添加的ops以保存变量。它需要启动图表的会话。要保存的变量也必须已初始化。

该方法返回新创建的检查点文件的路径前缀。该字符串可以直接传递给调用restore()。

 

参数说明:

  • sess:用于保存变量的会话。
  • save_path:字符串。为检查点创建的文件名的前缀。
  • global_step:如果提供,则附加全局步骤编号 save_path以创建检查点文件名。可选参数可以是a Tensor,Tensor名称或整数。
  • latest_filename:协议缓冲区文件的可选名称,其中包含最新检查点的列表。该文件与检查点文件保存在同一目录中,由保护程序自动管理,以跟踪最近的检查点。默认为'checkpoint'。
  • meta_graph_suffix:MetaGraphDef文件的后缀。默认为'meta'。
  • write_meta_graph:Boolean指示是否编写元图文件。
  • write_state:Boolean表示是否写入 CheckpointStateProto。
  • strip_default_attrs:布尔值。如果True,将从NodeDefs中删除默认值属性。有关详细指南,请参阅 sess:用于保存变量的会话。
  • save_path:字符串。为检查点创建的文件名的前缀。
  • global_step:如果提供,则附加全局步骤编号 save_path以创建检查点文件名。可选参数可以是a Tensor,Tensor名称或整数。
  • latest_filename:协议缓冲区文件的可选名称,其中包含最新检查点的列表。该文件与检查点文件保存在同一目录中,由保护程序自动管理,以跟踪最近的检查点。默认为'checkpoint'。
  • meta_graph_suffix:MetaGraphDef文件的后缀。默认为'meta'。
  • write_meta_graph:Boolean指示是否编写元图文件。
  • write_state:Boolean表示是否写入 CheckpointStateProto。
  • strip_default_attrs:布尔值。如果True,将从NodeDefs中删除默认值属性。有关详细指南,请参阅 Stripping Default-Valued Attributes 。

返回值:

字符串:用于检查点文件的路径前缀。如果保护程序是分片的,则该字符串以:' - ????? - of -nnnnn',其中'nnnnn'是创建的分片数。如果保护程序为空,则返回None。

 

可能出现的异常:

  • TypeError:如果sess不是Session。
  • ValueError:如果latest_filename包含路径组件,或者它与之冲突save_path。
  • RuntimeError:如果没有构建保存和恢复操作。

 

set_last_checkpoints(last_checkpoints)

弃用:使用set_last_checkpoints_with_time

设置旧的Checkpoint文件名列表

 

参数说明:

  • last_checkpoints:检查点文件名列表。

可能出现的异常:

  • AssertionError:如果last_checkpoints不是列表。

 

set_last_checkpoints_with_time(last_checkpoints_with_time)

设置旧Checkpoint文件名和时间戳的列表。

参数说明:

  • last_checkpoints_with_time:检查点文件名和时间戳的元组列表。

 

可能出现的异常:

  • AssertionError:如果last_checkpoints_with_time不是列表

 

to_proto(export_scope=None)

参数说明:

  • export_scope:可选string。要删除的名称范围。

返回值:SaverDef协议缓冲器

参考资料

https://www.tensorflow.org/api_docs/python/tf/train/Saver

你可能感兴趣的:(TensorFlow)