evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None)
使用验证集 input_fn 对 model 进行验证。
对于每一步,执行 input_fn(返回数据集的一个 batch)。
steps
个 batch,或者input_fn
抛出了出界异常(OutOfRangeError
或 StopIteration
)input_fn
:此函数构造出验证所需的输入数据,需要返回以下结构之一:
tf.data.Dataset
对象:Dataset
对象的输出必须是一个元组 (features, labels),和下面的规格相同。features
是一个 Tensor
或者字典(a dictionary of string feature name to Tensor
)。labels
是一个 Tensor
或者字典(a dictionary of string label name to Tensor
)。features
和 labels
都被 model_fn
所使用(model_fn
是 tf.estimator.Estimator
的构造函数的参数之一)。他们应该满足 model_fn
输入端的需求。steps
:验证模型的步数。如果是 None
,则一直验证下去,直至input_fn
抛出了出界异常。
hooks
:SessionRunHook
子类实例的 list。作为验证的回调函数。
checkpoint_path
:特定检查点的路径。如果是 None
,则默认为 model_dir
中最近的检查点(model_dir
是 tf.estimator.Estimator
的构造函数的参数之一)
name
:验证的名字。使用者可以针对不同的数据集运行多个验证操作,比如训练集 vs 测试集。不同验证的结果被保存在不同的文件夹中,且分别出现在 tensorboard 中。
返回一个字典,包括 model_fn
中指定的评价指标、global_step
。
ValueError
:如果 step
小于等于0
ValueError
:如果 model_dir
指定的模型没有被训练,或者指定的 checkpoint_path
为空。
先定义Estimator
:
cnn_model = tf.estimator.Estimator(
model_fn=model_function, model_dir=save_model_path
)
然后进行训练:
cnn_model.train(
input_fn=lambda: get_train_batch(train_file_path), steps=steps_per_eval)
最后进行验证:
evaluate_results = cnn_model.evaluate(
input_fn=lambda: get_val_batch(val_file_path),
steps=eval_steps_per_train_cycle)
其中,数据是从 tfrecords 中读取的:
def get_train_batch(data_dir, batch_size=conf.batch_size, set_name='train', use_distortion=True):
dataset = DataSet(data_dir, set_name, use_distortion)
return dataset.get_batch(data_dir, batch_size)
def get_val_batch(data_dir, batch_size=conf.batch_size, set_name='val', use_distortion=False):
dataset = DataSet(data_dir, set_name, use_distortion)
return dataset.get_batch(data_dir, batch_size)
class DataSet(object):
....
def get_batch(self, file_path, batch_size):
"""
:param batch_size: train, val, test batch_size is different
:param file_path:
:return:
"""
files = tf.data.Dataset.list_files(file_path)
dataset = files.apply(
tf.contrib.data.parallel_interleave(tf.data.TFRecordDataset, cycle_length=conf.num_parallel_readers,
sloppy=True))
if self.set_name == 'train':
dataset = dataset.repeat(conf.train_epochs)
dataset = dataset.shuffle(conf.shuffle_buffer_size)
dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=self.parser_single_img, batch_size=batch_size,
num_parallel_batches=conf.num_parallel_batches))
dataset = dataset.prefetch(conf.batch_size)
iterator = dataset.make_one_shot_iterator()
img_batch, label_batch = iterator.get_next()
return img_batch, label_batch