TensorFlow入门17: 创建自定义的Estimator 3

上节,详细分析了模型函数的签名(signature of model_fn), 本节通过分析鸢尾花的自定义模型函数(my_model() in custom_estimator.py),来总结模型函数的编写方法。

深度神经网络模型必须定义下列三个部分:

一个输入层; 一个或多个隐藏层; 一个输出层,如下图所示:


TensorFlow入门17: 创建自定义的Estimator 3_第1张图片

第一步,定义输入层,模型函数的第一行调用 tf.feature_column.input_layer,将字典类型的特征值和 feature_columns 转换为模型的输入,

tf.feature_column.input_layer(features, feature_columns, weight_collections=None, trainable=True, cols_to_vars=None)函数

第一个参数:features: 是一个字典类型,键(key)是每个特征列的名字,数值(value)是特征的数值,其类型取决于_FeatureColumn。把features变量print出来,如下图所示:


TensorFlow入门17: 创建自定义的Estimator 3_第2张图片
features

第二个参数:feature_columns:是输入到模型的,可迭代(iterable)的,包含了FeatureColumns的变量。所有的items都必须是从_DenseColumn类派生出来的类,例如numeric_column, embedding_column,bucketized_column, indicator_column等类,的实例

返回值:一个代表模型输入层的Tensor


TensorFlow入门17: 创建自定义的Estimator 3_第3张图片
输入层数据映射关系

第二步,定义隐藏层,深度神经网络必须包含一个或多个隐藏层,Layers API提供一组丰富的函数来定义所有类型的隐藏层,包括卷积层、池化层和丢弃层。对于鸢尾花,我们只需调用 tf.layers.dense 来创建隐藏层,并用 params['hidden_layers'] 定义维度。在 dense 层中,每个节点都连接到前一层中的各个节点。

tf.layers.dense(inputs,units,activation=None,use_bias=True,kernel_initializer=None,bias_initializer=tf.zeros_initializer(),kernel_regularizer=None,bias_regularizer=None,activity_regularizer=None,kernel_constraint=None,bias_constraint=None,trainable=True,name=None,reuse=None)函数,该函数实现操作: outputs = activation(inputs.kernel + bias) ,这里activation 是activation参数传入的activation function (if not None), kernel 是weights matrix, bias是bias vector (only if use_bias is True).

第一个参数:inputs 是本层输入的Tensor,也是上一层的输出。输入层 tf.feature_column.input_layer()的输出是 net, 所以net作为inputs参数输入。在第一次迭代中,net 表示输入层。在每次循环迭代时,tf.layers.dense 使用变量 net 创建一个新层,该层将前一层的输出作为其输入。

第二个参数:units是本层神经元的个数,对于dense层来说,也是输出的个数。

第三个参数:activation是本层使用的激活函数,activation=tf.nn.relu,是指使用relu作为激活函数。若activation=None,则意味着不使用激活函数,相当于是线性激活(linear activation)。

第三步,定义输出层,继续调用 tf.layers.dense 来定义输出层,输出层不需要使用激活函数

 logits = tf.layers.dense(net, params['n_classes'], activation=None)

经过这三步,创建出来的神经网络,如下图所示:


TensorFlow入门17: 创建自定义的Estimator 3_第4张图片

创建好神经网络模型后,还剩最后一步:编写实现预测、评估和训练的分支代码,该内容在下解详述。

你可能感兴趣的:(TensorFlow入门17: 创建自定义的Estimator 3)