通过监控数量来定期保存模型。
记录的每个指标:
meth:`~pytorch_lightning.core.lightning.log` or meth:`~pytorch_lightning.core.lightning.log_dict`
LightningModule是监视器键的候选。
训练结束后,使用:attr: ' best_model_path '来检索最佳检查点文件的路径,使用:attr: ' best_model_score '来检索它的分数。
def __init__(
self,
dirpath: Optional[_PATH] = None,
filename: Optional[str] = None,
monitor: Optional[str] = None,
verbose: bool = False,
save_last: Optional[bool] = None,
save_top_k: int = 1,
save_weights_only: bool = False,
mode: str = "min",
auto_insert_metric_name: bool = True,
every_n_train_steps: Optional[int] = None,
train_time_interval: Optional[timedelta] = None,
every_n_epochs: Optional[int] = None,
save_on_train_epoch_end: Optional[bool] = None,
):
参数解释:
Dirpath:模型文件存放的路径;(默认情况下,dirpath为' ' None ' ',并将在运行时设置为:class: ' ~pytorch_lightning.trainer.trainer.Trainer' 指定的位置。)
# custom path
# saves a file like: my/path/epoch=0-step=10.ckpt
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
filename:检查点文件名。可以包含自动填充的命名格式选项。
# save any arbitrary metrics like `val_loss`, etc. in name
# saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
>>> checkpoint_callback = ModelCheckpoint(
... dirpath='my/path',
... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
... )
monitor: 量监控。默认情况下,它是' ' None ' ',它只保存检查点的最后一个时期;
verbose: verbosity mode. Default: ``False``;
save_last:当' ' True ' '时,将检查点的精确副本保存到文件' last '。Ckpt '每当检查点文件被保存。这允许以确定的方式访问最新的检查点。默认值:' '没有' '
Save_top_k:如果' ' Save_top_k == k ',根据监控的数量,保存最好的k。如果' ' save_top_k == 0 ' ',则不保存任何模型。如果' ' save_top_k == -1 ' ',则保存所有模型。
save_weights_only:如果' ' True ' ',那么只有模型的权重会被保存。否则,优化器状态、lr调度器状态等也会添加到检查点中。
every_n_train_steps:检查点之间的训练步骤数;
every_n_epoch:检查点之间epoch的数量;
save_on_train_epoch_end:是否在训练epoch结束时运行检查点。如果这是' ' False ' ',那么检查将在验证结束时运行。