Reposted from:
https://stackoverflow.com/questions/46428604/how-to-implement-early-stopping-in-tensorflow
ValidationMonitor is marked as deprecated, so it is not recommended, but you can still use it. Here is an example of how to create one:
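import functools
from tensorflow.contrib.learn.python.learn import monitors

# input_fn is your own input function; here it is assumed to accept a `subset` argument
validation_monitor = monitors.ValidationMonitor(
    input_fn=functools.partial(input_fn, subset="evaluation"),
    eval_steps=128,
    every_n_steps=88,
    early_stopping_metric="accuracy",
    early_stopping_metric_minimize=False,  # accuracy must be maximized (the default is True)
    early_stopping_rounds=1000
)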
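ValidationMonitor minimizes the early-stopping metric by default (`early_stopping_metric_minimize=True`), so with `early_stopping_metric="accuracy"` the direction has to be flipped. The framework-free sketch below only mimics that bookkeeping to show why: the function name `steps_until_stop` and the list of `(global_step, value)` evaluations are hypothetical, and the rule "stop once `early_stopping_rounds` steps pass without improvement" is an approximation of the monitor's behaviour, not its source code.

```python
from typing import List, Optional, Tuple


def steps_until_stop(evals: List[Tuple[int, float]],
                     early_stopping_rounds: int,
                     minimize: bool = True) -> Optional[int]:
    """Return the global step at which training would stop, or None.

    evals: (global_step, metric_value) pairs, one per validation run
    (hypothetical data standing in for the monitor's periodic evaluations).
    """
    best_value = None
    best_step = None
    for step, value in evals:
        improved = (best_value is None
                    or (value < best_value if minimize else value > best_value))
        if improved:
            best_value, best_step = value, step
        elif step - best_step >= early_stopping_rounds:
            return step
    return None


# accuracy that keeps rising, evaluated every 88 steps
acc_evals = [(88 * i, 0.5 + 0.01 * i) for i in range(1, 30)]
print(steps_until_stop(acc_evals, 1000, minimize=False))  # → None (still improving)
print(steps_until_stop(acc_evals, 1000, minimize=True))   # → 1144 (wrongly stops)
```

With the default direction, every rise in accuracy counts as "no improvement", so training stops early even though the model is getting better.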
You can also implement it yourself. Here is my implementation:
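# presumably inside the after_run() method of a tf.train.SessionRunHook (run_context is its argument)
if loss_value < self.best_loss:
    # new best loss: reset the counter
    self.stopping_step = 0
    self.best_loss = loss_value
else:
    # no improvement at this step
    self.stopping_step += 1
if self.stopping_step >= FLAGS.early_stopping_step:
    self.should_stop = True
    print("Early stopping is triggered at step: {} loss: {}".format(global_step, loss_value))
    run_context.request_stop()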
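The snippet above is a fragment of a hook, so it cannot run on its own. The framework-free sketch below keeps exactly the same counter logic in a small class so it can be tried in isolation; the class name `EarlyStopping`, its `update` method, and the `patience` argument (which plays the role of `FLAGS.early_stopping_step`) are hypothetical.

```python
class EarlyStopping:
    """Counts consecutive non-improving steps, like the hook snippet above."""

    def __init__(self, patience: int):
        self.patience = patience          # plays the role of FLAGS.early_stopping_step
        self.best_loss = float("inf")
        self.stopping_step = 0
        self.should_stop = False

    def update(self, global_step: int, loss_value: float) -> bool:
        if loss_value < self.best_loss:
            self.stopping_step = 0        # new best: reset the counter
            self.best_loss = loss_value
        else:
            self.stopping_step += 1       # no improvement at this step
        if self.stopping_step >= self.patience:
            self.should_stop = True
            print("Early stopping is triggered at step: {} loss: {}".format(global_step, loss_value))
        return self.should_stop


stopper = EarlyStopping(patience=3)
losses = [1.0, 0.8, 0.9, 0.85, 0.95, 0.7]
for step, loss in enumerate(losses):
    if stopper.update(step, loss):
        break
print(step, stopper.best_loss)  # → 4 0.8
```

In the real hook, returning `True` here corresponds to calling `run_context.request_stop()`.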
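Here is my implementation of early stopping; you can adapt it:
Early stopping can be applied at certain points of the training process, for example at the end of each epoch. In my case, I monitor the test (validation) loss at every epoch and interrupt training once the test loss has not improved for 20 epochs (self.require_improvement = 20).
You can set the maximum number of epochs to 10000, 20000, or whatever you like (self.max_epochs = 10000).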
self.require_improvement = 20
self.max_epochs = 10000
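To see how `self.require_improvement` and `self.max_epochs` interact, here is a framework-free sketch of the epoch loop's stopping rule applied to a list of per-epoch losses. The loss values and the function name `epochs_run` are made up for illustration. Because the training function below tests `last_improvement > self.require_improvement`, training actually stops after `require_improvement + 1` consecutive epochs without improvement, unless `max_epochs` is reached first.

```python
from typing import Sequence


def epochs_run(epoch_losses: Sequence[float], require_improvement: int, max_epochs: int) -> int:
    """Return how many epochs the loop would run before stopping."""
    best_cost = float("inf")
    last_improvement = 0
    epoch = 0
    stop = False
    while epoch < max_epochs and not stop:
        avg_cost = epoch_losses[epoch]   # stands in for one epoch of training
        if avg_cost < best_cost:
            best_cost = avg_cost
            last_improvement = 0
        else:
            last_improvement += 1
        if last_improvement > require_improvement:
            stop = True                  # same strict ">" test as the original code
        epoch += 1
    return epoch


# loss improves for 5 epochs, then plateaus
losses = [1.0, 0.9, 0.8, 0.7, 0.6] + [0.65] * 100
print(epochs_run(losses, require_improvement=20, max_epochs=10000))  # → 26
print(epochs_run(losses, require_improvement=20, max_epochs=10))     # → 10
```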
Here is my training function, which uses early stopping:
import random
import numpy as np

def train(self):
    # training data
    train_input = self.Normalize(self.x_train)
    train_output = self.y_train.copy()
    # ===============
    # NOTE: this only keeps a reference to the same session object; it does not
    # snapshot the weights (tf.train.Saver is needed to really save/restore them)
    save_sess = self.sess
    # ===============
    # cost history:
    costs = []
    costs_inter = []
    # =================
    # for early stopping:
    best_cost = float("inf")
    stop = False
    last_improvement = 0
    # ================
    n_samples = train_input.shape[0]  # size of the training set
    # ===============
    # train the model on mini-batches using the early stopping criterion
    epoch = 0
    while epoch < self.max_epochs and not stop:
        # shuffle, then split the training set into mini-batches of size self.batch_size
        seq = list(range(n_samples))
        random.shuffle(seq)
        mini_batches = [
            seq[k:k + self.batch_size]
            for k in range(0, n_samples, self.batch_size)
        ]
        avg_cost = 0.  # average cost over the mini-batches of this epoch
        step = 0
        for sample in mini_batches:
            batch_x = train_input.iloc[sample, :]  # x_train is undefined here; use train_input
            batch_y = train_output.iloc[sample, :]
            batch_y = np.array(batch_y).flatten()
            feed_dict = {self.X: batch_x, self.Y: batch_y, self.is_train: True}
            _, cost, acc = self.sess.run([self.train_step, self.loss_, self.accuracy_], feed_dict=feed_dict)
            avg_cost += cost * len(sample) / n_samples  # weight each batch by its size
            print('epoch[{}] step [{}] train -- loss : {}, accuracy : {}'.format(epoch, step, avg_cost, acc))
            step += 1
        # cost history since the last best cost
        costs_inter.append(avg_cost)
        # early stopping: stop once this epoch's average *training* cost (avg_cost)
        # has not decreased for more than self.require_improvement epochs
        if avg_cost < best_cost:
            save_sess = self.sess  # "save" the session (reference only, see note above)
            best_cost = avg_cost
            costs += costs_inter  # cost history
            last_improvement = 0
            costs_inter = []
        else:
            last_improvement += 1
        if last_improvement > self.require_improvement:
            print("No improvement found during the last {} epochs, stopping optimization.".format(self.require_improvement))
            # break out of the loop
            stop = True
            self.sess = save_sess  # "restore" the session with the best cost
        # run validation after every epoch (is_train=False so dropout/batch norm run in inference mode)
        print('---------------------------------------------------------')
        self.y_validation = np.array(self.y_validation).flatten()
        loss_valid, acc_valid = self.sess.run([self.loss_, self.accuracy_],
                                              feed_dict={self.X: self.x_validation, self.Y: self.y_validation, self.is_train: False})
        print("Epoch: {0}, validation loss: {1:.2f}, validation accuracy: {2:.01%}".format(epoch + 1, loss_valid, acc_valid))
        print('---------------------------------------------------------')
        epoch += 1
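The per-epoch `avg_cost` above weights each mini-batch's cost by `len(sample) / n_samples`, so the smaller final batch counts proportionally less and, assuming `self.loss_` is the mean loss of the batch, `avg_cost` equals the mean per-sample loss over the whole epoch. The plain-Python sketch below checks that; the helper name `make_mini_batches` and the per-sample losses are invented for illustration.

```python
import random


def make_mini_batches(n_samples: int, batch_size: int, seed: int = 0):
    """Shuffle sample indices and split them into mini-batches (the last one may be smaller)."""
    seq = list(range(n_samples))
    random.Random(seed).shuffle(seq)
    return [seq[k:k + batch_size] for k in range(0, n_samples, batch_size)]


# made-up per-sample losses for a 10-sample training set
per_sample_loss = [0.1 * i for i in range(10)]
n_samples = len(per_sample_loss)
mini_batches = make_mini_batches(n_samples, batch_size=4)

avg_cost = 0.0
for sample in mini_batches:
    cost = sum(per_sample_loss[i] for i in sample) / len(sample)  # batch mean, like loss_
    avg_cost += cost * len(sample) / n_samples                    # weight by batch size

print([len(b) for b in mini_batches])                                  # → [4, 4, 2]
print(round(avg_cost, 6), round(sum(per_sample_loss) / n_samples, 6))  # → 0.45 0.45
```

Averaging the three batch costs without weights would over-count the two samples in the last batch.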
The important parts of the code can be summarized as follows:
def train(self):
    ...
    # cost history:
    costs = []
    costs_inter = []
    # for early stopping:
    best_cost = float("inf")
    stop = False
    last_improvement = 0
    # train the model on mini-batches using the early stopping criterion
    epoch = 0
    while epoch < self.max_epochs and not stop:
        ...
        for sample in mini_batches:
            ...
        # cost history since the last best cost
        costs_inter.append(avg_cost)
        # early stopping: stop once this epoch's average training cost (avg_cost)
        # has not decreased for more than self.require_improvement epochs
        if avg_cost < best_cost:
            save_sess = self.sess  # keeps a reference only, not a copy of the weights
            best_cost = avg_cost
            costs += costs_inter  # cost history
            last_improvement = 0
            costs_inter = []
        else:
            last_improvement += 1
        if last_improvement > self.require_improvement:
            print("No improvement found during the last {} epochs, stopping optimization.".format(self.require_improvement))
            # break out of the loop
            stop = True
            self.sess = save_sess  # "restore" the session with the best cost
        ...
        epoch += 1
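Note that `save_sess = self.sess` only binds a second name to the same session object, so `self.sess = save_sess` does not roll the weights back to the best epoch. In TensorFlow 1.x this is usually done by saving a checkpoint with `tf.train.Saver` whenever the cost improves and restoring it at the end. The plain-Python sketch below illustrates the difference between keeping a reference and keeping a snapshot, using a dict as a hypothetical stand-in for the model's variables; `train_with_best_snapshot` and its inputs are likewise hypothetical.

```python
import copy

weights = {"w": [0.0, 0.0]}           # hypothetical stand-in for the model's variables

saved_ref = weights                   # like save_sess = self.sess: same object
saved_copy = copy.deepcopy(weights)   # an actual snapshot of the current values

weights["w"][0] = 5.0                 # training keeps updating the variables in place

print(saved_ref["w"])                 # → [5.0, 0.0] (the "saved" reference changed too)
print(saved_copy["w"])                # → [0.0, 0.0] (the snapshot kept the old values)


def train_with_best_snapshot(epoch_updates, epoch_losses):
    """Apply per-epoch updates and keep a deep copy of the weights with the lowest loss."""
    params = {"w": 0.0}
    best_loss, best_params = float("inf"), copy.deepcopy(params)
    for delta, loss in zip(epoch_updates, epoch_losses):
        params["w"] += delta                          # one "epoch" of training
        if loss < best_loss:
            best_loss, best_params = loss, copy.deepcopy(params)
    return best_params                                # the best weights, not the last ones


print(train_with_best_snapshot([1.0, 1.0, 1.0], [0.5, 0.3, 0.4]))  # → {'w': 2.0}
```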
Since TensorFlow r1.10, early-stopping hooks for the Estimator API are available in early_stopping.py (see GitHub), for example tf.contrib.estimator.stop_if_no_decrease_hook (see the documentation).
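`stop_if_no_decrease_hook` takes the estimator, a metric name, and `max_steps_without_decrease` (plus optional arguments such as `min_steps`), and it counts patience in global training steps between evaluations rather than in epochs. In real code the hook would typically be passed to training, e.g. `tf.estimator.TrainSpec(..., hooks=[hook])`. The pure-Python sketch below approximates that step-based rule over a list of `(global_step, metric)` evaluation results; it is an illustration under that assumption, not the hook's source, and the function name `should_stop` and the data are hypothetical.

```python
from typing import Iterable, Tuple


def should_stop(eval_results: Iterable[Tuple[int, float]],
                max_steps_without_decrease: int,
                min_steps: int = 0) -> bool:
    """Approximate rule: stop if the metric has not decreased for
    max_steps_without_decrease global steps (evaluations before min_steps are ignored)."""
    best_val = None
    best_step = None
    for step, val in sorted(eval_results):
        if step < min_steps:
            continue
        if best_val is None or val < best_val:
            best_val, best_step = val, step
        if step - best_step >= max_steps_without_decrease:
            return True
    return False


# hypothetical loss values from evaluations every 100 global steps
evals = [(100, 0.9), (200, 0.7), (300, 0.72), (400, 0.71), (500, 0.75)]
print(should_stop(evals, max_steps_without_decrease=300))  # → True (no decrease since step 200)
print(should_stop(evals, max_steps_without_decrease=400))  # → False
```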