tf.estimator.Estimator是TF比较高级的接口。
最近在使用bert预训练模型的时候用到了tf.estimator.Estimator。使用该接口的时候需要开发者完成的工作比较少,一共3个步骤:
第一步,设置input_fun,第二步,设置model_fun,第三步,开始训练。
第一步的input_fun完成的功能是数据的输入准备工作,比如读取一个tfrecord文件,然后解析里面的内容,返回dataset;或者读取音频、图像等数据,返回相应的结果,目前来说返回的结果为dataset格式比较好。
第二步的model_fun完成的功能有:创建模型(输入feature,输出predict这种),设置loss,设置优化器,返回结果是tf.estimator.EstimatorSpec。(后续会说明tf.estimator.EstimatorSpec是什么,怎么设置)
第三步的开始训练是:参数准备(比如学习率什么的,就是上面的步骤1-2中需要用到的参数),设置config(用于训练模型是指定模型的保存路径,多长时间保存一次模型,使用GPU的一些情况),开始根据情况调用estimator.train 和 estimator.evaluate 或者 estimator.predict。
def input_fn(filenames, batch_size=32, num_epochs=None, perform_shuffle=False):
"""
每次调用,从TFRecord文件中读取一个大小为batch_size的batch
Args:
filenames: TFRecord文件
batch_size: batch_size大小
num_epochs: 将TFRecord中的数据重复几遍,如果是None,则永远循环读取不会停止
perform_shuffle: 是否乱序
Returns:
tensor格式的,一个batch的数据
"""
def _parse_fn(record):
features = {
"label": tf.FixedLenFeature([], tf.int64),
"image": tf.FixedLenFeature([], tf.string),
}
parsed = tf.parse_single_example(record, features)
# image
image = tf.decode_raw(parsed["image"], tf.uint8)
image = tf.reshape(image, [28, 28])
# label
label = tf.cast(parsed["label"], tf.int64)
return {"image": image}, label
# Extract lines from input files using the Dataset API, can pass one filename or filename list
dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000) # multi-thread pre-process then prefetch
# Randomizes input using a window of 256 elements (read into memory)
if perform_shuffle:
dataset = dataset.shuffle(buffer_size=256)
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size) # Batch size to use
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
return batch_features, batch_labels
def model_fn(features, labels, mode, params):
"""
:param features:
:param labels:
:param mode: 指定训练、验证和测试三种模式
tf.estimator.ModeKeys.TRAIN tf.estimator.ModeKeys.EVAL tf.estimator.ModeKeys.PREDICT
:param params: 包含学习率等超参数的设计
:return:
"""
# step1: 构建模型
logits = create_model(features)
predict = tf.nn.softmax(logits, axis=-1)
# step2: 构建loss、optimization等
loss = get_loss(logits, labels)
train_op = tf.train.GradientDescentOptimizer(params['lr']).minimize(loss)
# step3: 根据mode,构建不同情况下的tf.estimator.EstimatorSpec
# For mode == ModeKeys.TRAIN: 需要的参数是 loss and train_op.
# For mode == ModeKeys.EVAL: 需要的参数是 loss.
# For mode == ModeKeys.PREDICT: 需要的参数是 predictions.
if mode == tf.estimator.ModeKeys.TRAIN:
# logging_hook是模型训练/测试的工具,主要执行特定的任务,如判断是否需要停止训练的EarlyStopping,
# 改变学习速率的LearningRateScheduler,共性就是在每个step开始/结束或者每个epoch开始/结束时需要执行某个操作。
output_spec = tf.estimator.EstimatorSpec(
mode=mode,
loss=loss,
train_op=train_op,
training_hooks=[logging_hook])
elif mode == tf.estimator.ModeKeys.EVAL:
output_spec = tf.estimator.EstimatorSpec(
mode=mode,
loss=loss,
eval_metric_ops=eval_metrics)
else:
output_spec = tf.estimator.EstimatorSpec(
mode=mode,
predictions={"probabilities": predict})
return output_spec
def main_():
# 1. 设置超参数
params = {'lr', 0.0001}
# 2. 设置config,用于控制模型保存的位置,多久保存一次等
session_config = tf.ConfigProto(log_device_placement=False,
inter_op_parallelism_threads=0,
intra_op_parallelism_threads=0,
allow_soft_placement=True)
run_config = tf.estimator.RunConfig(model_dir=model_output_dir,
save_checkpoints_steps=5000,
keep_checkpoint_max=3,
session_config=session_config)
# 3. 开始训练
estimator = tf.estimator.Estimator(
model_fn=model_fn,
config=run_config,
params=params)
if do_train:
train_input_fn = input_fun(...)
estimator.train(input_fn=train_input_fn)
elif do_eval:
eval_input_fn = input_fun(...)
estimator.train(input_fn=eval_input_fn)
else:
predict_input_fn = input_fun(...)
estimator.train(input_fn=predict_input_fn)
===未完待续===
之后会更新关于hook等如何设置
参考文献:
https://zhuanlan.zhihu.com/p/129018863
https://zhuanlan.zhihu.com/p/106400162
https://www.jianshu.com/p/5495f87107e7