上节分析,我们知道创建自定义的Estimator,其主要工作就是编写自定义的模型函数。
由于模型函数是传给Estimator使用的,所以,打开~\Anaconda3\Lib\site-packages\tensorflow\python\estimator文件夹下estimator.py文件,可以看到对模型函数的要求:
细节,全在源代码中,建议大家一边阅读源代码,一边看本文
由上图可以看到:`Estimator` 对象封装了神经网络模型(model),这个模型由model_fn的参数指定(specify),model_fn会返回必要的操作(ops),用于训练、评估和预测。
model_fn的函数签名,如下图所示:
*基础知识:函数签名( function signature), 又叫做 type signature, 或 method signature,定义了 函数 或 方法.的输入输出,主要包括:
1)parameters and their types
2)a return value and type
3)exceptions that might be thrown or passed back
4)information about the availability of the method in an object-oriented program (such as the keywords public, static, or prototype).
由上图可知,Esitmator类定义了模型函数(model_fn)的签名(Signature),如下:
前两个参数,features, labels都好理解,是输入函数input_fn batch化后的传入的参数。在编写model_fn函数时,features,和labels这两个参数必须有!其余参数,根据情况,可以有,也可以无。
第三个参数,mode,是tf.estimator.ModeKeys类,该类有三个成员,分别是:
TRAIN: training mode.
EVAL: evaluation mode.
PREDICT: inference mode
当调用 train、evaluate 或 predict 方法时,Estimator 框架会调用模型函数并将 mode 参数设置为如下表所示的值:
第四个参数,params,是一个字典类型的配置参数,传给模型函数所有的超参数(Hyperparameters),如feature_columns、hidden_units、n_classesd等,例如:
params={
'feature_columns': my_feature_columns,
# Two hidden layers of 10 nodes each.
'hidden_units': [10, 10],
# The model must choose between 3 classes.
'n_classes': 3,
}
第五个参数,config,是RunConfig类的实例,里面包含执行环境(execution environment)的一系列参数,例如,model_dir、task_id、save_checkpoints_secs等等。若model_fn函数没有config这个参数,则Esitmator对象调用默认的run_config.RunConfig()函数来执行默认的执行环境配置,如下图所示:
大多数情况下,第五个参数是不用在自定义的model_fn函数中实现的,所以,通常情况下,model_fn函数原型简化为:
模型函数的签名分析完毕,下一节,将详细分析鸢尾花的自定义模型函数,并总结模型函数、输入函数以及特征列的编写方法。