[tensorflow]tf.estimator.Estimator构建tensorflow模型

目录

一、Estimator简介

二、数据集

三、定义特征列

四、estimator创建模型

五、模型训练、评估和预测

六、模型保存和恢复


一、Estimator简介

Estimator是TensorFlow对完整模型的高级表示。Tensorflow提供一个包含多个API层的编程堆栈:

[tensorflow]tf.estimator.Estimator构建tensorflow模型_第1张图片

Estimator封装了操作:训练、评估、预测、导出以供使用。

二、数据集

通过tf.data模块,构建输入管道,将数据传送到模型中。tf.data模块返回的是Dataset对象,每个Dataset包含(feature_dict, labels)对。

https://blog.csdn.net/woniu201411/article/details/89249689

三、定义特征列

特征列视为原始数据和Estimator之间的媒介。要创建特征列,需要调用tf.feature_column模块的函数。

[tensorflow]tf.estimator.Estimator构建tensorflow模型_第2张图片

1、数值列

tf.feature_column.numeric_column将具有默认数据类型(tf.float32)的数值指定为模型输入。

2、分桶列

tf.feature_column.bucketized_column将数字列根据数值范围分为不同的类别(为模型中加入非线性特征,提高模型的表达能力)。

3、分类标识列

tf.feature_column.categorical_column_with_identity将每个分桶表示一个唯一整数,模型可以在分类标识列中学习每个类别各自的权重。

4、分类词汇列

tf.feature_column.categorical_column_with_vocabulary_list将字符串表示为独热矢量,根据明确的词汇表将每个字符串映射到一个整数。

tf.feature_column.categorical_column_with_vocabulary_file将字符串表示为独热矢量,根据文件中的词汇将每个字符串映射到一个整数。

5、经过哈希处理的列

tf.feature_column.categorical_column_with_hash_bucket将类别数量非常大的特征列,模型会计算输入的哈希值,然后使用模运算符将其置于其中一个hash_bucket_size类别中。

6、特征组合列

tf.feature_column.categorical_column_with_hash_bucket将任意分类列进行组合,但仅构建hash_bucket_size参数所请求的类别数量。

7、指标列和嵌入列

指标列(tf.feature_column.indicator_column)和嵌入列(tf.feature_column.embedding_column)将分类列视为输入。

四、estimator创建模型

预创建的Estimator是tf.estimator.Estimator基类的子类,而自定义的Estimator是tf.estimator.Estimator的实例。两者的使用区别在于,预创建的Estimator已有模型函数,而自定义的Estimator需要自己编写模型函数。

[tensorflow]tf.estimator.Estimator构建tensorflow模型_第3张图片

 

1、预创建的estimator

Tensorflow提供了三个预创建的分类器Estimator(Estimator代表一个完整的模型):

tf.estimator.DNNClassifier 多类别分类的深度模型

tf.estimator.LinearClassifier 基于线性模型的分类器

tf.estimator.DNNLinearCombinedClassifier 宽度和深度模型

2、自定义的estimator

定义模型函数,模型参数具有以下参数

def my_model_fn(features, labels, mode, params):

features、labels是从输入函数中返回的特征和标签批次。

model表示调用程序是请求训练、预测还是评估。tf.estimator.ModeKeys

params是调用程序将params传递给Estimator的构造函数,转而又传递给model_fn.例如:

classifier = tf.estimator.Estimator(
    model_fn=my_model,
    params={
        'feature_columns': my_feature_columns,
        # Two hidden layers of 10 nodes each.
        'hidden_units': [10, 10],
        # The model must choose between 3 classes.
        'n_classes': 3,
    })

模型-输入层:将特征字典和feature_columns转换为模型的输入

模型隐藏层:tf.layers提供所有类型的隐藏层,包括卷积层、池化层和丢弃层。

模型输出层:tf.layers.dense定义输出层。使用tf.nn.softmax将分数转换为概率。

五、模型训练、评估和预测

Estimator方法 Estimator模式
train() ModeKeys.TRAIN
evaluate() ModeKeys.EVAL
predict() ModeKeys.PREDICT

 

1、模型训练

classifer.train(

    input_fn = lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),

    max_steps=args.train_steps)

Estimator会调用模型函数并将mode设为ModeKeys.TRAIN

input_fn:输入数据。将input_fn调用封装在lambda中以获取参数,提供一个不采用任何参数的输入函数。

max_steps:模型训练的最多步数。

在my_model_fn中,定义损失函数和优化损失函数的方法:

# Calculate Loss (for both TRAIN and EVAL modes)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)  # 多类别分类问题,采用softmax交叉熵用作损失函数

    # Configure the Training Op (for TRAIN mode)
    # 采用随机梯度下降法优化损失函数,学习速率为0.001
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
        train_op = optimizer.minimize(
            loss=loss,
            global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

2、模型评估

# Evaluate the model.
eval_result = classifier.evaluate(
    input_fn=lambda:iris_data.eval_input_fn(test_x, test_y,args.batch_size))

print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))

Estimator会调用模型函数并将mode设为ModeKeys.EVAL。模型函数必须返回一个包含模型损失和一个或多个指标(可选)的tf.estimator.EstimatorSpec.

使用tf.estrics计算常用指标

# Add evaluation metrics (for EVAL mode), 准确率指标
eval_metric_ops = {"accuracy": tf.metrics.accuracy(labels=labels,
predictions=predictions["classes"])}
return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

3、模型预测

predictions = classifier.predict(
    input_fn=lambda:iris_data.eval_input_fn(predict_x,labels=None,batch_size=args.batch_size))

调用Estimator的predict方法,则model_fn会收到mode=ModeKeys.PREDICT,模型函数返回一个包含预测的tf.estimator.EstimatorSpec.

 predictions = {
        # Generate predictions (for PREDICT and EVAL mode)
        "classes": tf.argmax(input=logits, axis=1),
        # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
        # `logging_hook`.
        "probabilities": tf.nn.softmax(logits, name="softmax_tensor"),
        # Generate image feature vector
        "feature": dense
    }

if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

predictions存储的是三个键值对:

classes:存储的是模型对此样本预测的最有可能的类别id;

probabilities:存储的是样本属于各个类别的概率值;

features:存储的是样本的特征向量(倒数第二层)。

六、模型保存和恢复

Estimator自动将模型信息写入磁盘:检查点,训练期间所创建的模型版本;事件文件,包含TensorBoard用于创建可视化图表的信息。在Estimator的构造函数model_dir参数中定义模型保存路径。

模型保存:

[tensorflow]tf.estimator.Estimator构建tensorflow模型_第4张图片

如图所示,第一次调用train会将检查点和事件文件添加到model_dir目录中。

默认情况下,Estimator按照以下时间安排将检查点保存到model_dir中:每10分钟(600秒)写入一个检查点;在train方法开始(第一次迭代)和完成(最后一次迭代)时写入一个检查点;在目录中保留5个最近写入的检查点。

通过tf.estimator.RunConfig对默认保存时间更改:

my_checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs = 20*60,  # Save checkpoints every 20 minutes.
    keep_checkpoint_max = 10,       # Retain the 10 most recent checkpoints.
)

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris',
    config=my_checkpointing_config)

模型恢复:

[tensorflow]tf.estimator.Estimator构建tensorflow模型_第5张图片

第一次调用estimator的train方法时,TensorFlow将第一个检查点保存到model_dir中,随后每次调用Estimator的train、evaluate或predict方法时,都会:Estimator运行model_fn构建模型图;Estimator根据最近写入的检查点中存储的数据来初始化新模型的权重。

通过检查点恢复模型的状态仅在模型和检查点兼容时可行。例如,训练一个DNNClassifier estimator,它包含2个隐藏层且每层都有10个节点,在训练后,将每个隐藏层中的神经元数量从10改为20,然后重新训练模型,由于检查点中的状态与模型不兼容,会出现错误:

does not match the shape stored in checkpoint.

 

参考资料:

https://www.tensorflow.org/guide/premade_estimators#evaluate_the_trained_model

https://www.tensorflow.org/guide/custom_estimators

https://www.tensorflow.org/guide/checkpoints

https://www.tensorflow.org/guide/feature_columns

 

你可能感兴趣的:(Tensorflow)