自定义estimator

Tensorflow从1.3版本开始推出了官方支持的高层封装tf.estimator。Estimators API提供了一整套训练模型、测试模型以及生成预测的方法。

自定义模型函数

Tensorflow支持自定义estimator,首先需要定义一个模型函数model_fn,函数有4个输入:features,labels,mode和params。
features为模型的输入,labels为预测的真实值
mode的取值有3种:tf.estimator.ModeKeys.TRAINtf.estimator.ModeKeys.EVALtf.estimator.ModeKeys.PREDICT,分别对应训练,验证和测试。通过mode的值,可以判断当前属于哪一个阶段。params是一个字典,包含模型相关的超参数,例如learning rate等。
自定义函数model_fn返回值必须是一个tf.estimator.EstimatorSpec对象,

  def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metric_ops=None,
              export_outputs=None,
              training_chief_hooks=None,
              training_hooks=None,
              scaffold=None,
              evaluation_hooks=None,
              prediction_hooks=None):

其中,mode表示模型的使用模式,对应model_fn的参数mode;predictions表示根据输入的特征features计算返回的预测值;loss表示损失;train_op表示对模型的损失进行最小化的op;eval_metric_ops表示模型在eval时,需要额外输出的指标。export_outputs表示导出模型的路径。还有一些钩子函数。
当mode不同,EstimatorSpec所需的参数也不一样。如果mode为TRAIN,则实例化EstimatorSpec时,必须设置参数losstrain_op,当mode为EVAL时,必须设置参数loss,当mode为PREDICT时,必须设置参数predictions

def my_model(features, labels, mode, params):
    W = tf.Variable(tf.random_normal([1]), name="weight")
    b = tf.Variable(tf.zeros([1]), name="bias")
    predictions = tf.multiply(W, tf.cast(features, dtype=tf.float32)) + b
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)
    loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions)
    mean_loss = tf.metrics.mean(loss)
    metrics = {'mean_loss':mean_loss}
    if mode == tf.estimator.ModeKeys.EVAL:
        # eval_metric_ops`用来定义评价指标,在运行eval的时候会计算这里定义的所有评测标准。
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=metrics) 
    assert mode == tf.estimator.ModeKeys.TRAIN
    optimizer = tf.train.AdagradDAOptimizer(learning_rate=params["learning_rate"], global_step=tf.train.get_or_create_global_step())
    train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

实例化estimator

最后通过实例化tf.estimator.Estimator就可以得到一个自定义的estimator。

def __init__(self,
             model_fn: Any,
             model_dir: Any = None,
             config: Any = None,
             params: Any = None,
             warm_start_from: Any = None) -> Any

参数model_fn即为自定义的模型函数,model_dir用于保存模型的参数和模型图等内容。warm_start_from用来指定检查点路径,并导入checkpoint开始训练。warm_start_from可以通过tf.estimator.WarmStartSettings实例化。

def __new__(cls,
            ckpt_to_initialize_from: Any,
            vars_to_warm_start: str = '.*',
            var_name_to_vocab_info: Any = None,
            var_name_to_prev_var_name: Any = None) -> _T

ckpt_to_initialize_from可以指定加载checkpoint的路径,vars_to_warm_start指定哪些参数需要热启动。

代码

代码自定义estimator

参考

  1. 深度学习之tensorflow工程化项目实战。

你可能感兴趣的:(tensorflow)