利用 tf.estimator
训练模型时需要写两个重要的函数,一个用于数据输入的函数(input_fn
),另一个用于模型创建的函数(model_fn
)。下面逐一来说明。(数据格式采用 TFRecord)。
input_fn
,(就是先把图片、标签打包成一个dataset形式)返回如下两种格式之一:tf.data.Dataset 对象:这个对象的输出必须是元组队 (features, labels),而且必须满足下一条返回格式的同等约束;
元组 (features, labels):features 以及 labels 都必须是一个张量或由张量组成的字典。
下面是一个例子:
def train(data_dir):
#遍历出图片路径列表
paths = walk_type(data_dir +r'\train\\*\\','*.bmp')
#遍历出标签列表
labels = []
for path in paths:
label = int(path[14:16])
labels.append(label)
# 图片路径列表转tensor常量
filenames = tf.constant(paths
)#
# 标签列表转tensor常量
labels = tf.constant(labels)
# tensor常量转dataset
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
#
# 此时dataset中的一个元素是(image_resized, label)
dataset = dataset.map(_parse_function)#
return dataset
def train_input_fn(data_dir, # data
params #{'learning_rate': 0.001, 'batch_size': 1, 'num_epochs': 20, 'num_channels': 32, 'use_batch_norm': False, 'bn_momentum': 0.9, 'margin': 0.5, 'embedding_size': 64, 'triplet_strategy': 'batch_all', 'squared': False, 'image_size': 28, 'num_labels': 10, 'train_size': 50000, 'eval_size': 10000, 'num_parallel_calls': 4, 'save_summary_steps': 50}
):
# 把data_dir数据集中的image和label打包成元组张量
dataset = img_label_to_dataset.train(data_dir # data
) #
# 打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小。单位是以图片(张量)为单位,而不是byte;
dataset = dataset.shuffle(params['train_size'] # 50000
) #
# 将整个数据集重复 params['num_epochs'] 次
dataset = dataset.repeat(params['num_epochs'] # 20
) #
# 将 params['batch_size'] 个元素组合成batch
dataset = dataset.batch(params['batch_size'] # 1
) #
# 预先载入(并行运算,加快数据运算速度)
dataset = dataset.prefetch(1) #
return dataset
tf.estimator.EstimatorSpec
的完整形式是:tf.estimator.EstimatorSpec(
mode, #指定当前是处于训练、验证还是预测状态
predictions=None, #是预测的一个张量,或者是由张量组成的一个字典
loss=None, #是损失张量
train_op=None, #指定优化操作
eval_metric_ops=None, #指定各种评估度量的字典
export_outputs=None, #参数 export_outputs 只用于模型保存,描述了导出到 SavedModel 的输出格式
training_chief_hooks=None,
training_hooks=None,
scaffold=None, #是一个 tf.train.Scaffold 对象,可以在训练阶段初始化、保存等时使用。
evaluation_hooks=None,
prediction_hooks=None
)
model_fn
,返回类 tf.estimator.EstimatorSpec 的一个实例。model_fn 的完整定义形式是(函数名任取):def model_fn(
features, #从input_fn中传入
labels, #从input_fn中传入
mode, #指定训练模式,可以取 (TRAIN, EVAL, PREDICT)三者之一
params=None #是一个(可要可不要的)字典,指定其它超参数。
):
params = params or {}
loss, train_op, ... = None, None, ...
prediction_dict = ...
if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
loss = ...#必填项(损失函数)
if mode == tf.estimator.ModeKeys.TRAIN:
train_op = ...#必填项(训练图)
if mode == tf.estimator.ModeKeys.PREDICT:
predictions = ...#必填项(预测结果)
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=prediction_dict, #预测结果
loss=loss, #损失函数
train_op=train_op, #训练图
...)
tf.estimator.TrainSpec
指定训练输入函数及相关参数。该类的完整形式是:train_spec = tf.estimator.TrainSpec(
input_fn, #提供训练时的输入数据
max_steps, #指定总共训练多少步
hooks #一个 tf.train.SessionRunHook 对象,用来配置分布式训练等参数。
)
tf.estimator.EvalSpec
指定验证输入函数及相关参数。该类的完整形式是:EvalSpec = tf.estimator.EvalSpec(
input_fn, #用来提供验证时的输入数据
steps=100, #指定总共验证多少步(一般设定为 None 即可)
name=None,
hooks=None, #用来配置分布式训练等参数
exporters=None, #Exporter 迭代器,会参与到每次的模型验证
start_delay_secs=120, #指定多少秒之后开始模型验证
throttle_secs=600 #指定多少秒之后重新开始新一轮模型验证
)
tf.estimator.Estimator
定义 Estimator 实例 estimator。类 Estimator 的完整形式是:estimator = tf.estimator.Estimator(
model_fn, #模型函数
model_dir=None, #训练时模型保存的路径
config=None, #tf.estimator.RunConfig 的配置对象
params=None, #传入 model_fn 的超参数字典
warm_start_from=None #或者是一个预训练文件的路径,或者是一个 tf.estimator.WarmStartSettings 对象,用于完整的配置热启动参数
)
# Train the Model.
Estimator.train(
input_fn = lambda: train_input_fn(args.data_dir,
params
))
# Evaluate the model.
eval_result = classifier.evaluate(
input_fn = lambda: train_input_fn(args.data_dir,
params
))
predictions = classifier.predict(
input_fn=lambda:train_input_fn(args.data_dir,
params
))
tf.estimator.train_and_evaluate
启动训练和验证过程。该函数的完整形式是:tf.estimator.train_and_evaluate(
estimator, #tf.estimator.Estimator 对象,用于指定模型函数以及其它相关参数
train_spec, #tf.estimator.TrainSpec 对象,用于指定训练的输入函数以及其它参数
eval_spec # tf.estimator.EvalSpec 对象,用于指定验证的输入函数以及其它参数
)