tf.estimator.train_and_evaluate(
estimator,
train_spec,
eval_spec
)
定义于:
tensorflow/python/estimator/training.py。
训练和评估estimator
。
该通用函数使用给定的estimator
训练,评估和(可选地)导出模型。所有训练相关的规范都包含在train_spec
内,包括训练input_fn
和最大训练次数等。所有评估和导出相关的规范都包含在eval_spec
内,包括评估input_fn
,次数等。
此通用函数本地(非分布式)和分布式配置是一致的。目前,唯一支持的分布式培训配置是图间复制。
Overfitting:为了避免Overfitting,建议设置训练input_fn
以适当地改变训练数据。在进行评估之前,还建议将模型多训练一段时间,比如多个epochs,因为输入管道从头开始进行每次训练。这对本地的训练和评估尤为重要。
Stop condition:为了可靠地支持分布式和非分布式配置,模型训练唯一支持的Stop condition是train_spec.max_steps
。如果train_spec.max_steps
是None
,模型将永远训练下去。如果模型Stop condition不同,请小心使用。例如,假设预期模型将使用一个epoch训练数据进行训练,并且训练input_fn
被配置为 在经过一个epoch训练之后抛出OutOfRangeError
,停止训练Estimator.train
。对于 three-training-worker分布式配置,每个training worker可能在整个epoch独立地完成训练。因此,该模型将使用三个epoches训练数据而不是一个epoch进行训练。
本地(非分布式)训练示例:
# Set up feature columns.
categorial_feature_a = categorial_column_with_hash_bucket(...)
categorial_feature_a_emb = embedding_column(
categorical_column=categorial_feature_a, ...)
... # other feature columns
estimator = DNNClassifier(
feature_columns=[categorial_feature_a_emb, ...],
hidden_units=[1024, 512, 256])
# Or set up the model directory
# estimator = DNNClassifier(
# config=tf.estimator.RunConfig(
# model_dir='/my_model', save_summary_steps=100),
# feature_columns=[categorial_feature_a_emb, ...],
# hidden_units=[1024, 512, 256])
# Input pipeline for train and evaluate.
def train_input_fn: # returns x, y
# please shuffle the data.
pass
def eval_input_fn_eval: # returns x, y
pass
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
分布式训练示例:
关于分布式训练的示例,上面的代码可以在没有更改的情况下使用(请确保将RunConfig.model_dir
所有workers 设置为相同的目录,即所有workers 都可以读写的共享文件系统)。唯一需要做的额外工作是相应地为每个workers 正确设置环境变量TF_CONFIG
。
另请参阅 Distributed TensorFlow。
设置环境变量取决于平台。例如,在Linux上,它可以按如下方式完成($
是shell提示符):
$ TF_CONFIG='' python train_model.py
对于内容TF_CONFIG
,假设训练集群规范如下:
cluster = {"chief": ["host0:2222"],
"worker": ["host1:2222", "host2:2222", "host3:2222"],
"ps": ["host4:2222", "host5:2222"]}
TF_CONFIG
主要训练workers 的例子(必须有一个且只有一个):
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster": {
"chief": ["host0:2222"],
"worker": ["host1:2222", "host2:2222", "host3:2222"],
"ps": ["host4:2222", "host5:2222"]
},
"task": {"type": "chief", "index": 0}
}'
请注意,主要workers 也进行模型训练工作,类似于其他非主要训练workers (见下一段)。除了模型训练之外,它还管理一些额外的工作,例如检查点保存和恢复,写入summaries等。
TF_CONFIG
非主要训练workers 的示例(可选,可以是多个):
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster": {
"chief": ["host0:2222"],
"worker": ["host1:2222", "host2:2222", "host3:2222"],
"ps": ["host4:2222", "host5:2222"]
},
"task": {"type": "worker", "index": 0}
}'
其中task.index
应分别设定为0,1,2,在这个例子中,用于非主要训练workers 。
TF_CONFIG
参数服务器的示例,也就是ps(可能是多个):
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster": {
"chief": ["host0:2222"],
"worker": ["host1:2222", "host2:2222", "host3:2222"],
"ps": ["host4:2222", "host5:2222"]
},
"task": {"type": "ps", "index": 0}
}'
其中task.index
应分别设置为0和1,在本例中,分别为参数服务器。
TF_CONFIG
评估任务的示例。Evaluator是一项特殊任务,不属于训练集群。可能只有一个,它用于模型评估。
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
TF_CONFIG='{
"cluster": {
"chief": ["host0:2222"],
"worker": ["host1:2222", "host2:2222", "host3:2222"],
"ps": ["host4:2222", "host5:2222"]
},
"task": {"type": "evaluator", "index": 0}
}'
参数:
estimator
:Estimator
训练和评估的实例。
train_spec
:TrainSpec
指定训练规范的实例。
eval_spec
:EvalSpec
指定评估和导出规范的实例。
返回:
evaluate的结果元组,和指定ExportStrategy导出的结果。
目前,分布式训练模式的返回值未定义。