tf.estimator的用法

tf.estimator的用法

利用 tf.estimator 训练模型时需要写两个重要的函数,一个用于数据输入的函数(input_fn),另一个用于模型创建的函数(model_fn)。下面逐一来说明。(数据格式采用 TFRecord)。

首先我们从输入到输出调用顺序来介绍一下大概的训练过程(完整官方文档:tf.estimator):

定义输入函数 input_fn,(就是先把图片、标签打包成一个dataset形式)返回如下两种格式之一:

tf.data.Dataset 对象:这个对象的输出必须是元组队 (features, labels),而且必须满足下一条返回格式的同等约束;
元组 (features, labels):features 以及 labels 都必须是一个张量或由张量组成的字典。
下面是一个例子:

def train(data_dir):
    #遍历出图片路径列表
    paths = walk_type(data_dir +r'\train\\*\\','*.bmp')
    #遍历出标签列表
    labels = []
    for path in paths:
        label = int(path[14:16])
        labels.append(label)
    # 图片路径列表转tensor常量
    filenames = tf.constant(paths
                            )#
    # 标签列表转tensor常量
    labels = tf.constant(labels)
    # tensor常量转dataset
    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    # 
    # 此时dataset中的一个元素是(image_resized, label)
    dataset = dataset.map(_parse_function)#
    return dataset
def train_input_fn(data_dir, # data
                   params #{'learning_rate': 0.001, 'batch_size': 1, 'num_epochs': 20, 'num_channels': 32, 'use_batch_norm': False, 'bn_momentum': 0.9, 'margin': 0.5, 'embedding_size': 64, 'triplet_strategy': 'batch_all', 'squared': False, 'image_size': 28, 'num_labels': 10, 'train_size': 50000, 'eval_size': 10000, 'num_parallel_calls': 4, 'save_summary_steps': 50}
                   ):
    # 把data_dir数据集中的image和label打包成元组张量
    dataset = img_label_to_dataset.train(data_dir # data
                                  ) #
    # 打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小。单位是以图片(张量)为单位,而不是byte;
    dataset = dataset.shuffle(params['train_size'] # 50000
                              )  # 
    # 将整个数据集重复 params['num_epochs'] 次
    dataset = dataset.repeat(params['num_epochs'] # 20
                             )  # 
    # 将 params['batch_size'] 个元素组合成batch
    dataset = dataset.batch(params['batch_size'] # 1
                            ) # 
    # 预先载入(并行运算,加快数据运算速度)
    dataset = dataset.prefetch(1)  # 
    return dataset

tf.estimator.EstimatorSpec 的完整形式是:

tf.estimator.EstimatorSpec(
    mode,                       #指定当前是处于训练、验证还是预测状态
    predictions=None,           #是预测的一个张量,或者是由张量组成的一个字典
    loss=None,                  #是损失张量
    train_op=None,              #指定优化操作
    eval_metric_ops=None,       #指定各种评估度量的字典
    export_outputs=None,        #参数 export_outputs 只用于模型保存,描述了导出到 SavedModel 的输出格式
    training_chief_hooks=None,
    training_hooks=None,
    scaffold=None,              #是一个 tf.train.Scaffold 对象,可以在训练阶段初始化、保存等时使用。
    evaluation_hooks=None,
    prediction_hooks=None
                           )

定义模型函数 model_fn,返回类 tf.estimator.EstimatorSpec 的一个实例。model_fn 的完整定义形式是(函数名任取):

def model_fn(
features,    #从input_fn中传入
labels,      #从input_fn中传入
mode,        #指定训练模式,可以取 (TRAIN, EVAL, PREDICT)三者之一
params=None  #是一个(可要可不要的)字典,指定其它超参数。
            ):
    params = params or {}
    loss, train_op, ... = None, None, ...
    prediction_dict = ...
    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
        loss = ...#必填项(损失函数)
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = ...#必填项(训练图)
    if mode == tf.estimator.ModeKeys.PREDICT:   
        predictions = ...#必填项(预测结果)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=prediction_dict,  #预测结果
        loss=loss,                    #损失函数
        train_op=train_op,            #训练图
        ...)

使用 tf.estimator.TrainSpec 指定训练输入函数及相关参数。该类的完整形式是:

train_spec = tf.estimator.TrainSpec(
input_fn,    #提供训练时的输入数据
max_steps,   #指定总共训练多少步
hooks        #一个 tf.train.SessionRunHook 对象,用来配置分布式训练等参数。
                                     )

使用 tf.estimator.EvalSpec 指定验证输入函数及相关参数。该类的完整形式是:

EvalSpec = tf.estimator.EvalSpec(
    input_fn,              #用来提供验证时的输入数据
    steps=100,             #指定总共验证多少步(一般设定为 None 即可)
    name=None,             
    hooks=None,            #用来配置分布式训练等参数           
    exporters=None,        #Exporter 迭代器,会参与到每次的模型验证
    start_delay_secs=120,  #指定多少秒之后开始模型验证
    throttle_secs=600      #指定多少秒之后重新开始新一轮模型验证 
                     )

使用 tf.estimator.Estimator 定义 Estimator 实例 estimator。类 Estimator 的完整形式是:

estimator = tf.estimator.Estimator(
model_fn,              #模型函数
model_dir=None,        #训练时模型保存的路径
config=None,           #tf.estimator.RunConfig 的配置对象
params=None,           #传入 model_fn 的超参数字典
warm_start_from=None   #或者是一个预训练文件的路径,或者是一个 tf.estimator.WarmStartSettings 对象,用于完整的配置热启动参数
                                   )

下面就有两种调用方法:

  1. 在 Estimator 对象上调用一个或多个方法,传递适当的输入函数作为数据的来源

    estimator.train()

    # Train the Model.
    Estimator.train(
              input_fn = lambda: train_input_fn(args.data_dir,
                                                params 
                                                ))
    

    estimator.evaluate()

    # Evaluate the model.
    eval_result = classifier.evaluate(
         input_fn = lambda: train_input_fn(args.data_dir,
                                           params 
                                           ))
    

    estimator.predict()

    predictions = classifier.predict(
        input_fn=lambda:train_input_fn(args.data_dir,
                                       params 
                                       ))
    
  2. 使用 tf.estimator.train_and_evaluate 启动训练和验证过程。该函数的完整形式是:

    tf.estimator.train_and_evaluate(
    estimator,  #tf.estimator.Estimator 对象,用于指定模型函数以及其它相关参数
    train_spec, #tf.estimator.TrainSpec 对象,用于指定训练的输入函数以及其它参数   
    eval_spec   # tf.estimator.EvalSpec 对象,用于指定验证的输入函数以及其它参数
            )
    

你可能感兴趣的:(tf.estimator的用法)