TensorFlow入门18: 创建自定义的Estimator 4

上节,详细分析了在自定义模型函数中,创建神经网络的三个步骤:创建输入层、创建隐藏层、创建输出层。本节主要介绍自定义模型函数的最后一步:编写实现预测、评估和训练的分支代码

回忆一下:《TensorFlow入门16: 创建自定义的Estimator 2》

                  1,Model_fn的返回值是: tf.estimator.EstimatorSpec。

                  2,Estimator对象的三个方法train、evaluate、predict都会调用model_fn给Estimator传参数。

                  3,当Estimator对象调用 train、evaluate 或 predict 方法时,Estimator 对象会在调用模型函数前,将 mode 参数设置为对应的值:ModeKeys.TRAIN、ModeKeys.EVAL、ModeKeys.PREDICT。

由此,model_fn函数创建好神经网络后,检测mode值,根据不同的mode,实现对应的代码,并返回: tf.estimator.EstimatorSpec,具体的实现,参考下图:


TensorFlow入门18: 创建自定义的Estimator 4_第1张图片
model_fn的代码实现


完成model_fn函数编写后,回到main函数,可以发现,只有创建classifier对象的代码,略有不同,其余代码一模一样,如下图所示


TensorFlow入门18: 创建自定义的Estimator 4_第2张图片
创建classifier对象代码比较

你可能感兴趣的:(TensorFlow入门18: 创建自定义的Estimator 4)