Tensorflow从1.3版本开始推出了官方支持的高层封装tf.estimator
。Estimators API提供了一整套训练模型、测试模型以及生成预测的方法。
自定义模型函数
Tensorflow支持自定义estimator,首先需要定义一个模型函数model_fn,函数有4个输入:features,labels,mode和params。
features为模型的输入,labels为预测的真实值
mode的取值有3种:tf.estimator.ModeKeys.TRAIN
,tf.estimator.ModeKeys.EVAL
和tf.estimator.ModeKeys.PREDICT
,分别对应训练,验证和测试。通过mode的值,可以判断当前属于哪一个阶段。params是一个字典,包含模型相关的超参数,例如learning rate等。
自定义函数model_fn返回值必须是一个tf.estimator.EstimatorSpec
对象,
def __new__(cls,
mode,
predictions=None,
loss=None,
train_op=None,
eval_metric_ops=None,
export_outputs=None,
training_chief_hooks=None,
training_hooks=None,
scaffold=None,
evaluation_hooks=None,
prediction_hooks=None):
其中,mode
表示模型的使用模式,对应model_fn的参数mode;predictions
表示根据输入的特征features
计算返回的预测值;loss
表示损失;train_op
表示对模型的损失进行最小化的op;eval_metric_ops
表示模型在eval时,需要额外输出的指标。export_outputs
表示导出模型的路径。还有一些钩子函数。
当mode不同,EstimatorSpec所需的参数也不一样。如果mode为TRAIN
,则实例化EstimatorSpec时,必须设置参数loss
和train_op
,当mode为EVAL
时,必须设置参数loss
,当mode为PREDICT
时,必须设置参数predictions
。
def my_model(features, labels, mode, params):
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")
predictions = tf.multiply(W, tf.cast(features, dtype=tf.float32)) + b
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode, predictions=predictions)
loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
mean_loss = tf.metrics.mean(loss)
metrics = {'mean_loss':mean_loss}
if mode == tf.estimator.ModeKeys.EVAL:
# eval_metric_ops`用来定义评价指标,在运行eval的时候会计算这里定义的所有评测标准。
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=metrics)
assert mode == tf.estimator.ModeKeys.TRAIN
optimizer = tf.train.AdagradDAOptimizer(learning_rate=params["learning_rate"], global_step=tf.train.get_or_create_global_step())
train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_or_create_global_step())
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
实例化estimator
最后通过实例化tf.estimator.Estimator
就可以得到一个自定义的estimator。
def __init__(self,
model_fn: Any,
model_dir: Any = None,
config: Any = None,
params: Any = None,
warm_start_from: Any = None) -> Any
参数model_fn
即为自定义的模型函数,model_dir
用于保存模型的参数和模型图等内容。warm_start_from
用来指定检查点路径,并导入checkpoint开始训练。warm_start_from可以通过tf.estimator.WarmStartSettings
实例化。
def __new__(cls,
ckpt_to_initialize_from: Any,
vars_to_warm_start: str = '.*',
var_name_to_vocab_info: Any = None,
var_name_to_prev_var_name: Any = None) -> _T
ckpt_to_initialize_from
可以指定加载checkpoint的路径,vars_to_warm_start
指定哪些参数需要热启动。
代码
参考
- 深度学习之tensorflow工程化项目实战。