AllenNLP源码学习——trainer

trainer是模型训练的中枢,它的内部控制着模型训练的各个组件,如model,iterator,datasets,num_epochs,optimizer,读取与保存,打印输出,早停,summary_interval(多少个epoch用tensorboard记录一次),should_log_learning_rate(是否记录学习率变化)等。

一个epoch

全部数据训练一轮结束后,输入一遍验证集。

训练: 一个batch送入model,得到一个loss,以最近更新的QANet为例,用的nll_loss (它有一项参数默认为reduction=‘mean’,即一个batch的loss求均值),把所有batch的loss求和train_loss += loss.item(),求均值并返回值

metrics = training_util.get_metrics(self.model, train_loss, batches_this_epoch)
  1. get_metrics是调用模型特定的get_metrics(比如返回accuracy,f1),加上各种模型都要有的loss。
  2. 在epoch循环内的这句是为了实时记录一个epoch内的batch loss(model的accuracy,f1等也同时记录)变化到tensorboardX。
  3. 循环外的get_metrics设置了reset=True,即一个epoch结束时,返回所有样本loss的均值后,将loss值清零(model的accuracy,f1等值也清零),以便下一个epoch使用。
  4. 可见,模型中定义的def get_metrics(self, reset: bool = False),是模型向外界输出数据的端口。trainer控制iterator从datasets取数据,送入model,trainer从get_metrics获得模型的结果。

验证集: val_loss, num_batches = self._validation_loss()中进行,得到loss和train一致。

在训练过程中,控制台最下方进度条的最大值为一轮中mini-batch(iteration)的数量

保存检查点checkpoint

  1. Checkpointer类控制保存(save_checkpoint)与(restore_checkpoint)加载最近一个检查点文件(加载到cpu然后转到GPU),加载最佳的模型文件(best_model_state),包括model_state_epoch_{}.thtraining_state_epoch_{}.th
  2. 一个epoch结束(train+validation)后,保存一次
  3. model_state是model.state_dict()获得的模型参数
  4. training_states包含metric_tracker(控制早停,找到验证集上得分最高的一轮),optimizer的参数,batch_num_total(训练结束的batch数量),learning_rate_scheduler的参数。
  5. trainer的参数num_serialized_models_to_keep(默认为20)设置保存最多多少个检查点,多于20个时,从第一个开始依次删除。
  6. trainer的参数keep_serialized_model_every_num_seconds默认为None,如果设置了一个整数(多少秒),那么在Checkpointer记录的列表中压入新数据,删除列表第一个数据时和前一个豁免删除的检查点的差值大于设定的这个时间,则这个数据不被删除(程序中List中记录的数据被删除)
  7. trainer的参数model_save_interval默认为None,如果设置了一个整数(多少秒),则训练中经过这么长时间保存一个检查点,每输入一个batch都进行一次时间的判断。
  8. best.th是什么时候得到的?每次保存检测点时,trainer从MetricTracker那里得知当前是不是最佳self._metric_tracker.is_best_so_far(),如果是最佳,则检查点将这个model_state.th复制一份,命名为 best.th shutil.copyfile(model_path, os.path.join(self._serialization_dir, "best.th"))

MetricTracker

MetricTracker类控制模型什么时候停止训练(比如过拟合的时候我们需要停止),在trainer中,传入它的是验证集输出的metrics(注意和训练集无关),其中有增加的accuracy,f1等,也有要减少的loss。
trainer参数——>MetricTracker:

  1. patience:当模型的性能经过多少个epoch没有增加时,停止训练。
  2. validation_metric:默认为"-loss",以loss为例,前缀“-”(accuracy就为"+")即我们需要它减少,如果增加了,说明性能变差了。如果 >= patience个epoch依然如此,则停止训练。

过程: 一轮训练+验证结束后,验证集的metrics传入metric_tracker。
如果loss值增加了,则self._epochs_with_no_improvement += 1self._epochs_with_no_improvement >= self._patience时,通知trainer要break出epoch的循环了。
如果loss值依然减少,更新最佳值self._best_so_far,self._epochs_with_no_improvement清零,顺便self.best_epoch = self._epoch_number记录下哪一个epoch的验证集loss值最小。

val_metrics = training_util.get_metrics(self.model, val_loss, num_batches, reset=True)

 # Check validation metric for early stopping
this_epoch_val_metric = val_metrics[self._validation_metric]  
#_validation_metric默认为"loss",则只将loss值传入_metric_tracker进行判断
self._metric_tracker.add_metric(this_epoch_val_metric)

正则项

在model的构造函数中,参数有词典和正则项。

def __init__(self,
                 vocab: Vocabulary,
                 regularizer: RegularizerApplicator = None)

trainer将一个batch的数据送进modelself.batch_loss()之后,得到batch的平均loss值,(如果是训练)加一个正则项。

loss = output_dict["loss"]
if for_training:
      loss += self.model.get_regularization_penalty()

你可能感兴趣的:(allennlp,学习AllenNLP)