tf.estimator API技术手册(8)——DNNClassifier(深度神经网络分类器)

tf.estimator API技术手册(8)——DNNClassifier(深度神经网络分类器)

  • (一)简 介
  • (二)初始化
  • (三)属 性(Properties)
  • (四)方 法(Methods)
    • (1)evaluate(评估)
    • (2)predict(预测)
    • (3)train(训练)

(一)简 介

继承自Estimator,定义在tensorflow/python/estimator/canned/dnn.py中,用来建立深度神经网络模型。示例如下:

categorical_feature_a = categorical_column_with_hash_bucket(...)
categorical_feature_b = categorical_column_with_hash_bucket(...)

categorical_feature_a_emb = embedding_column(
    categorical_column=categorical_feature_a, ...)
categorical_feature_b_emb = embedding_column(
    categorical_column=categorical_feature_b, ...)

# 有三个隐层,结点数分别为1024,512,256个
estimator = DNNClassifier(
    feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
    hidden_units=[1024, 512, 256])

# 创建训练数据输入函数
def input_fn_train: # returns x, y
  pass
estimator.train(input_fn=input_fn_train, steps=100)
# 创建评估数据输入函数
def input_fn_eval: 
  pass
metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
def input_fn_predict: # returns x, None
  pass
predictions = estimator.predict(input_fn=input_fn_predict)

(二)初始化

__init__(
    hidden_units,
    feature_columns,
    model_dir=None,
    n_classes=2,
    weight_column=None,
    label_vocabulary=None,
    optimizer='Adagrad',
    activation_fn=tf.nn.relu,
    dropout=None,
    input_layer_partitioner=None,
    config=None,
    warm_start_from=None,
    loss_reduction=losses.Reduction.SUM,
    batch_norm=False
)

参数如下:

  • hidden_units:
    设置隐层层数和每一层的结点数,如[64, 32]代表第一隐层有64个节点,第二隐层有32个节点,所有的隐层都是全连接的。

  • feature_columns:
    特征列

  • model_dir:
    保存模型的路径

  • n_classes:
    标签的种类,默认为2

  • weight_column:
    由 tf.feature_column.numeric_column创建的一个字符串或者数字列用来呈现特征列。它将会被乘以example的训练损失。

  • label_vocabulary:
    一个字符串列表用来呈现可能的标签取值,如果给出,则必须为字符型,如果没有给出,则会默认编码为整型,为{0, 1,…, n_classes-1} 。

  • optimizer:
    选择优化器,默认使用Adagrad optimizer,激活函数为tf.nn.relu。

  • input_layer_partitioner:
    输入层分割器,min_max_variable_partitioner和min_slice_size默认为64 << 20

  • config:
    一个运行配置对象,用来配置运行时间。

  • warm_start_from:
    A string filepath to a checkpoint to warm-start from, or a WarmStartSettings object to fully configure warm-starting. If the string filepath is provided instead of a WarmStartSettings, then all weights are warm-started, and it is assumed that vocabularies and Tensor names are unchanged.

  • loss_reduction:
    定义损失函数,默认为SUM方法

  • batch_norm:
    是否要在每个隐层之后使用批量归一化。

(三)属 性(Properties)

  • config
  • model_dir
  • model_fn
    Returns the model_fn which is bound to self.params.

返回:
model_fn 附有以下标记: def model_fn(features, labels, mode, config)

(四)方 法(Methods)

(1)evaluate(评估)

evaluate(
    input_fn,
    steps=None,
    hooks=None,
    checkpoint_path=None,
    name=None
)

评估函数,使用input_fn给出的评估数据评估训练好的模型,参数列表如下:

  • input_fn:
    一个用来构造用于评估的数据的函数,这个函数应该构造和返回如下的值:一个tf.data.Dataset对象或者一个包含 (features, labels)的元组,它们应当满足model_fn函数对输入数据的要求,在后面的实例中我们会详细介绍。
  • checkpoint_path:
    用来保存训练好的模型
  • name:
    如果用户需要在不同的数据集上运行多个评价,如训练集和测试集,则为要进行评估的名称,不同的评估度量被保存在单独的文件夹中,并分别出现在tensorboard中。

(2)predict(预测)

predict(
   input_fn,
   predict_keys=None,
   hooks=None,
   checkpoint_path=None,
   yield_single_examples=True
)

使用训练好的模型对新实例进行预测,以下为参数列表:

  • input_fn:
    一个用来构造用于评估的数据的函数,这个函数应该构造和返回如下的值:一个tf.data.Dataset对象或者一个包含 (features, labels)的元组,它们应当满足model_fn函数对输入数据的要求,在后面的实例中我们会详细介绍。

  • predict_keys:
    预测函数最终会返回一系列的结果,但我们可以有选择地让其输出,可供选择的keys列表为[‘logits’, ‘logistic’, ‘probabilities’, ‘class_ids’, ‘classes’],如果不指定的话,默认返回所有值。

  • hooks:
    tf.train.SessionRunHook的子类实例列表,在预测调用中用于传回。

  • checkpoint_path:
    训练好的模型的目录

  • yield_single_examples:
    可以选择False或是True,如果选择False,由model_fn返回整个批次,而不是将批次分解为单个元素。当model_fn返回的一些的张量的第一维度和批处理数量不相等时,这个功能是很用的。

(3)train(训练)

train(
    input_fn,
    hooks=None,
    steps=None,
    max_steps=None,
    saving_listeners=None
)

用于训练模型,以下为参数列表:

  • input_fn:
    一个用来构造用于评估的数据的函数,这个函数应该构造和返回如下的值:一个tf.data.Dataset对象或者一个包含 (features, labels)的元组,它们应当满足model_fn函数对输入数据的要求,在后面的实例中我们会详细介绍。

  • hooks:
    tf.train.SessionRunHook的子类实例列表,在预测调用中用于传回。

  • steps:
    模型训练的次数,如果不指定,则会一直训练知道input_fn传回的数据消耗完为止。如果你不想要增量表现,就设置max_steps来替代,注意设置了steps,max_steps必须为None,设置了max_steps,steps必须为None。

  • max_steps:
    模型训练的总次数,注意设置了steps,max_steps必须为None,设置了max_steps,steps必须为None。

  • saving_listeners:
    CheckpointSaverListener对象的列表,用于在检查点保存之前或之后立即运行的回调。

你可能感兴趣的:(tf.estimator,API技术手册)