Checkpoints

This document examines how to save and restore TensorFlow models built with Estimators. TensorFlow provides two model formats:
checkpoints, a format that depends on the code that created the model
SavedModel, a format that is independent of the code that created the model


For details on SavedModel, see the Saving and Restoring chapter of the TensorFlow programmer's guide.


Estimators automatically write the following to disk:
checkpoints, which are versions of the model created during training
event files, which contain information that TensorBoard uses to create visualizations
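As an illustration of what lands in the model directory, the sketch below (plain Python, no TensorFlow required; the filenames are hypothetical but follow the usual `model.ckpt-<step>` naming) picks out the most recent checkpoint from a directory listing by global step:

```python
import re

# Hypothetical contents of a model directory after some training:
# checkpoint data files plus a TensorBoard event file.
files = [
    "checkpoint",
    "model.ckpt-1.index",
    "model.ckpt-1.data-00000-of-00001",
    "model.ckpt-200.index",
    "model.ckpt-200.data-00000-of-00001",
    "events.out.tfevents.1538000000.myhost",
]

def latest_step(filenames):
    """Return the highest global step found among checkpoint filenames."""
    steps = []
    for name in filenames:
        m = re.match(r"model\.ckpt-(\d+)\.index$", name)
        if m:
            steps.append(int(m.group(1)))
    return max(steps)

print(latest_step(files))  # -> 200
```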


The model_dir argument of an Estimator's constructor specifies the directory in which the Estimator stores checkpoints. For example:

###
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir="models/iris")


If you don't specify model_dir in an Estimator's constructor, the Estimator writes checkpoint files to a temporary directory chosen by Python's tempfile.mkdtemp function.
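To see the kind of directory an Estimator falls back to, you can call tempfile.mkdtemp directly (a minimal stdlib sketch, independent of TensorFlow):

```python
import os
import tempfile

# tempfile.mkdtemp creates a fresh, uniquely named directory under the
# platform's temp location and returns its absolute path.
model_dir = tempfile.mkdtemp()
print(model_dir)                 # e.g. /tmp/tmpYm1Rwa on Linux
print(os.path.isdir(model_dir))  # True

# Clean up the demo directory.
os.rmdir(model_dir)
```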


By default, the Estimator saves checkpoints in the model_dir according to the following schedule:
writes a checkpoint every 10 minutes (600 seconds)
writes a checkpoint when the train method starts (first iteration) and completes (final iteration)
retains only the 5 most recent checkpoints in the directory
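The retention rule can be pictured with a small pure-Python sketch (an illustration of the "keep only the N most recent" behavior, not TensorFlow's actual implementation):

```python
from collections import deque

def simulate_retention(steps_written, keep_max=5):
    """Keep only the keep_max most recent checkpoints, like the default schedule."""
    kept = deque(maxlen=keep_max)  # oldest entries fall off automatically
    for step in steps_written:
        kept.append(f"model.ckpt-{step}")
    return list(kept)

# Eight checkpoints written; only the 5 most recent survive.
print(simulate_retention([100, 200, 300, 400, 500, 600, 700, 800]))
# -> ['model.ckpt-400', 'model.ckpt-500', 'model.ckpt-600',
#     'model.ckpt-700', 'model.ckpt-800']
```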


You may alter the default schedule by taking the following steps:
create a RunConfig object that defines the desired schedule
when instantiating the Estimator, pass that RunConfig object to the Estimator's config argument


For example, the following code changes the schedule to save a checkpoint every minute and retain the 10 most recent checkpoints:

###
my_checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs=1 * 60,  # save checkpoints every minute
    keep_checkpoint_max=10)        # retain the 10 most recent checkpoints

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir="models/iris",
    config=my_checkpointing_config)


Once checkpoints exist, each subsequent call to the Estimator's train, evaluate, or predict method causes the following:
the Estimator builds the model's graph by running model_fn() (for details on model_fn(), see Creating Custom Estimators)
the Estimator initializes the weights of the new model from the data stored in the most recent checkpoint
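The most recent checkpoint is recorded in a small text file named `checkpoint` inside model_dir. The sketch below writes and parses such a file with plain Python to show how the "most recent" checkpoint is tracked (the paths are hypothetical, and real code would use tf.train.latest_checkpoint instead):

```python
import os
import re
import tempfile

# A typical "checkpoint" state file, as TensorFlow writes it (hypothetical paths).
state = '''model_checkpoint_path: "model.ckpt-800"
all_model_checkpoint_paths: "model.ckpt-700"
all_model_checkpoint_paths: "model.ckpt-800"
'''

model_dir = tempfile.mkdtemp()
with open(os.path.join(model_dir, "checkpoint"), "w") as f:
    f.write(state)

def latest_checkpoint(model_dir):
    """Read the checkpoint state file and return the most recent checkpoint path."""
    with open(os.path.join(model_dir, "checkpoint")) as f:
        m = re.search(r'model_checkpoint_path: "([^"]+)"', f.read())
    return m.group(1)

print(latest_checkpoint(model_dir))  # -> model.ckpt-800
```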


To run experiments in which you train and compare slightly different versions of a model, save a copy of the code that created each model_dir, possibly by creating a separate git branch for each version. This separation will keep your checkpoints recoverable.
