TensorFlow入门16: 创建自定义的Estimator 2

上节分析,我们知道创建自定义的Estimator,其主要工作就是编写自定义的模型函数。

由于模型函数是传给Estimator使用的,所以,打开~\Anaconda3\Lib\site-packages\tensorflow\python\estimator文件夹下estimator.py文件,可以看到对模型函数的要求:

细节,全在源代码中,建议大家一边阅读源代码,一边看本文

TensorFlow入门16: 创建自定义的Estimator 2_第1张图片

由上图可以看到:`Estimator` 对象封装了神经网络模型(model),这个模型由model_fn的参数指定(specify),model_fn会返回必要的操作(ops),用于训练、评估和预测。

model_fn的函数签名,如下图所示:


TensorFlow入门16: 创建自定义的Estimator 2_第2张图片
The signature of model_fn

*基础知识:函数签名( function signature), 又叫做 type signature, 或 method signature,定义了 函数 或 方法.的输入输出,主要包括:

1)parameters and their types

2)a return value and type

3)exceptions that might be thrown or passed back

4)information about the availability of the method in an object-oriented program (such as the keywords public, static, or prototype).

由上图可知,Esitmator类定义了模型函数(model_fn)的签名(Signature),如下:


TensorFlow入门16: 创建自定义的Estimator 2_第3张图片
Signature of model_fn

前两个参数,features, labels都好理解,是输入函数input_fn batch化后的传入的参数。在编写model_fn函数时,features,和labels这两个参数必须有!其余参数,根据情况,可以有,也可以无。

第三个参数,mode,是tf.estimator.ModeKeys类,该类有三个成员,分别是:

TRAIN: training mode.

EVAL: evaluation mode.

PREDICT: inference mode

当调用 train、evaluate 或 predict 方法时,Estimator 框架会调用模型函数并将 mode 参数设置为如下表所示的值:


TensorFlow入门16: 创建自定义的Estimator 2_第4张图片

第四个参数,params,是一个字典类型的配置参数,传给模型函数所有的超参数(Hyperparameters),如feature_columns、hidden_units、n_classesd等,例如:

params={

            'feature_columns': my_feature_columns,

            # Two hidden layers of 10 nodes each.

            'hidden_units': [10, 10],

            # The model must choose between 3 classes.

            'n_classes': 3,

        }

第五个参数,config,是RunConfig类的实例,里面包含执行环境(execution environment)的一系列参数,例如,model_dir、task_id、save_checkpoints_secs等等。若model_fn函数没有config这个参数,则Esitmator对象调用默认的run_config.RunConfig()函数来执行默认的执行环境配置,如下图所示:


TensorFlow入门16: 创建自定义的Estimator 2_第5张图片

大多数情况下,第五个参数是不用在自定义的model_fn函数中实现的,所以,通常情况下,model_fn函数原型简化为:


TensorFlow入门16: 创建自定义的Estimator 2_第6张图片
常用的model_fn原型

模型函数的签名分析完毕,下一节,将详细分析鸢尾花的自定义模型函数,并总结模型函数、输入函数以及特征列的编写方法。

你可能感兴趣的:(TensorFlow入门16: 创建自定义的Estimator 2)