在 这篇博客 中训练CNN的时候,即便是对fc层加了dropout,对loss加了L2正则化,依然出现了过拟合的情况(如下图所示),于是开始尝试用early stop解决拟合问题。
(训练集的loss在下降而测试集的loss却在5k步左右开始上升,说明过拟合了)
想要实现ES,首先需要知道loss的值(以便根据loss值在xx次迭代内的变化决定是否需要停止training),Tensorflwo中提供了hook来帮助我们从graph中“钩取”出想要的值。
这里用到的hook是 tf.contrib.estimator.stop_if_no_decrease_hook ,它的源码可以 看这里
tf.contrib.estimator.stop_if_no_decrease_hook(
estimator,
metric_name,
max_steps_without_decrease,
eval_dir=None,
min_steps=0,
run_every_secs=60,
run_every_steps=None
)
其中metric_name用来指明监控的变量(比如loss或者accuracy),在这篇博文中的run_mnist()代码如下:
def run_mnist(params):
model_helpers.apply_clean(params) # 清空model_dir文件夹下的旧文件
#实例化estimator
paramsdic = params.flag_values_dict()
model = tf.estimator.Estimator(model_fn=cnn_model_fn,model_dir=params.model_dir,params=paramsdic) #Estimator的构造函数会把params传给model_fn
#为啥不能params=params??因为传入的params是一个类!!!absl.flags._flagvalues.FlagValues类,需要调用函数flag_values_dict()将他的属性转化成dic才能被传入model_fn
#没转化成dic时用params.dropout_rate 代表取出属性
#转化成dic后用params['dropout_rate'] #代表取出key对应的value
#实例化hooks(用于监控台输出程序运行的记录日志,记录哪些量由tensor_to_log字典给出)而tensorboard的图似乎和hook没关系?
tensor_to_log={'prob':'softmax_tensor'}#打印prob,其值来源于softmax_tensor
train_hooks = hooks_helper.get_train_hooks(name_list=params.hooks,model_dir=params.model_dir,)#tensors_to_log=tensor_to_log)
os.makedirs(model.eval_dir())
train_hoooks_for_earlyStoping = stop_if_no_decrease_hook(model,eval_dir=model.eval_dir(),metric_name='accuracy',max_steps_without_decrease=1000,min_steps=100)
#必须使用loss而不是eval_loss,因为train里自动记录的是名字为‘loss’的值
#input_fn函数
def train_input_fn():#这里虽然返回的是一个ds但是实际上这个是被zip(feature,label)的ds,可以直接被parse成feature,label [也就是 model.train中需要input_fn返回的形式]
ds = dataset.train(params.data_dir)
ds = ds.cache().shuffle(buffer_size=50000).batch(params.batch_size)
ds = ds.repeat(params.epochs_between_evals)
return ds
def eval_input_fn():
return dataset.test(params.data_dir).batch(
params.batch_size).make_one_shot_iterator().get_next()
#每次返回一个(fea,lab)对??
#为啥eval的input返回的是迭代器而train的input返回的是整个的dataset??
#train和eval
for i in range(params.train_epochs // params.epochs_between_evals):
# tf.estimator.train_and_evaluate(model,train_spec=tf.estimator.TrainSpec(train_input_fn,hooks=[train_hoooks_for_earlyStoping]),
# eval_spec=tf.estimator.EvalSpec(eval_input_fn))
model.train(input_fn=train_input_fn,hooks=[train_hoooks_for_earlyStoping])# 如果这里参数传入了 hooks=train_hooks 那么model_fn中的train就要把注释的几个identity解开
if train_hoooks_for_earlyStoping.stopFlag == True :
break
eval_results = model.evaluate(input_fn=eval_input_fn)
print('\nEvaluation results:\n\t%s\n' % eval_results)
if model_helpers.past_stop_threshold(params.stop_threshold,
eval_results['accuracy']):
break
其中 定义ES的hook语句如下所示,他返回的是一个实例化的hook:
train_hoooks_for_earlyStoping = stop_if_no_decrease_hook(model,eval_dir=model.eval_dir(),metric_name='accuracy',max_steps_without_decrease=1000,min_steps=100)
由于只是将这个hook传给了model.train(),所以只能让train停止,此后会接着执行model.eval()以及进入下一次循环,因此并没有真正起到EarlyStopping的作用(整个程序停止,这里只是让model.train()停止),所以需要对stop_if_no_decrease_hook的源码进行修改,为这个类增加一个属性,用来标识是否开始ES:
修改的源码如下所示:
class _StopOnPredicateHook(session_run_hook.SessionRunHook):
"""Hook that requests stop when `should_stop_fn` returns `True`."""
def __init__(self, should_stop_fn, run_every_secs=60, run_every_steps=None):
if not callable(should_stop_fn):
raise TypeError('`should_stop_fn` must be callable.')
#增加这个tag !!!!!
self.stopFlag = False
#增加这个tag !!!!!
self._should_stop_fn = should_stop_fn
self._timer = basic_session_run_hooks.SecondOrStepTimer(
every_secs=run_every_secs, every_steps=run_every_steps)
self._global_step_tensor = None
self._stop_var = None
self._stop_op = None
def begin(self):
self._global_step_tensor = training_util.get_global_step()
self._stop_var = _get_or_create_stop_var()
self._stop_op = state_ops.assign(self._stop_var, True)
def before_run(self, run_context):
del run_context
return session_run_hook.SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context, run_values):
global_step = run_values.results
if self._timer.should_trigger_for_step(global_step):
self._timer.update_last_triggered_step(global_step)
if self._should_stop_fn():
self.stopFlag = True
tf_logging.info('Requesting early stopping at global step %d',
global_step)
run_context.session.run(self._stop_op)
run_context.request_stop()
class _CheckForStoppingHook(session_run_hook.SessionRunHook):
"""Hook that requests stop if stop is requested by `_StopOnPredicateHook`."""
def __init__(self):
self._stop_var = None
#增加这个tag !!!!!
self.stopFlag = False
#增加这个tag !!!!!
def begin(self):
self._stop_var = _get_or_create_stop_var()
def before_run(self, run_context):
del run_context
return session_run_hook.SessionRunArgs(self._stop_var)
def after_run(self, run_context, run_values):
should_early_stop = run_values.results
if should_early_stop:
self.stopFlag = True
tf_logging.info('Early stopping requested, suspending run.')
run_context.request_stop()
这样的话在loop的循环中:
for i in range(params.train_epochs // params.epochs_between_evals):
# tf.estimator.train_and_evaluate(model,train_spec=tf.estimator.TrainSpec(train_input_fn,hooks=[train_hoooks_for_earlyStoping]),
# eval_spec=tf.estimator.EvalSpec(eval_input_fn))
model.train(input_fn=train_input_fn,hooks=[train_hoooks_for_earlyStoping])# 如果这里参数传入了 hooks=train_hooks 那么model_fn中的train就要把注释的几个identity解开
if train_hoooks_for_earlyStoping.stopFlag == True :
break
eval_results = model.evaluate(input_fn=eval_input_fn)
print('\nEvaluation results:\n\t%s\n' % eval_results)
if model_helpers.past_stop_threshold(params.stop_threshold,
eval_results['accuracy']):
break
model.train()之后立刻验证新增加的tag的值,如果是true,说明上面的model.train()不是正常结束的,而是由于ES结束的,此时立刻跳出循环,结束整个训练和测试loop。
最后得到的结果如下:
可以看到大约4k步之后,hook检测到eval的loss很久没有降低了,于是进行了ES(没有ES时会迭代25k步左右,见上图)
参考链接:
Implement early stopping in tf.estimator.DNNRegressor using the available training hooks
Early stopping with tf.estimator, how?