Configuring pytorch-lightning to save a checkpoint at the end of every training epoch

ModelCheckpoint
In pytorch-lightning, checkpoints are saved through the ModelCheckpoint callback, which only writes a checkpoint after the validation loop finishes. This does not fit some tasks: Transformer-based speech recognition models, for example, usually average the weights of the last 10-20 epochs, and for an autoregressive model a genuine validation pass (i.e. one that does not feed in the ground-truth labels) is slow, so the training run may contain no validation loop at all.
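
As a rough illustration of the weight averaging mentioned above, the sketch below averages the state dicts of several Lightning checkpoints. The file names are hypothetical; it only assumes the model weights sit under the ``state_dict`` key, as they do in checkpoints written by pytorch-lightning.

import torch

def average_checkpoints(ckpt_paths):
    """Average the parameters stored in a list of Lightning checkpoint files."""
    avg_state = None
    for path in ckpt_paths:
        state = torch.load(path, map_location='cpu')['state_dict']
        if avg_state is None:
            avg_state = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg_state[k] += v.float()
    # divide by the number of checkpoints to get the element-wise mean
    return {k: v / len(ckpt_paths) for k, v in avg_state.items()}

# hypothetical paths to the last few epoch checkpoints
averaged = average_checkpoints(['epoch=48.ckpt', 'epoch=49.ckpt', 'epoch=50.ckpt'])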

SaveCheckpoint
Subclass ModelCheckpoint so that a checkpoint is saved at the end of every training epoch:

from pytorch_lightning.callbacks import ModelCheckpoint


class SaveCheckpoint(ModelCheckpoint):
    """Save a checkpoint after each training epoch, without requiring validation.

    If ``last_k == -1``, every epoch's model is saved and no monitor is needed.
    Otherwise, log ``global_step`` in ``training_step``,
    e.g. ``self.log('global_step', self.global_step)``, so the callback has a
    monotonically increasing quantity to rank checkpoints by.

    :param last_k: number of most recent models to keep (-1 keeps all).
    :param save_weights_only: if ``True``, only the model's weights are saved;
        otherwise the full training state is saved.
    """
    def __init__(self, last_k=5, save_weights_only=False):
        if last_k == -1:
            # keep every checkpoint; no monitored metric is required
            super().__init__(save_top_k=-1, save_last=False,
                             save_weights_only=save_weights_only)
        else:
            # rank checkpoints by global_step so only the newest last_k are kept
            super().__init__(monitor='global_step', mode='max', save_top_k=last_k,
                             save_last=False, save_weights_only=save_weights_only)

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        """Save a checkpoint at the end of every training epoch.

        Note: newer pytorch-lightning releases drop the ``outputs`` argument
        from this hook; adjust the signature to match your installed version.
        """
        self.save_checkpoint(trainer, pl_module)

    def on_validation_end(self, trainer, pl_module):
        """Override ModelCheckpoint so nothing is saved at the end of the validation loop."""
        pass
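
A minimal usage sketch, assuming a toy LightningModule (the model, data, and ``train_dataloader`` below are hypothetical): log ``global_step`` in ``training_step`` so the callback's monitor has a value, then pass the callback to the Trainer.

import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # required when last_k != -1: SaveCheckpoint monitors 'global_step'
        self.log('global_step', self.global_step)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

trainer = pl.Trainer(max_epochs=20, callbacks=[SaveCheckpoint(last_k=5)])
# trainer.fit(LitModel(), train_dataloader)  # train_dataloader is hypothetical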
