Pytorch Lightning系列 如何使用ModelCheckpoint

在训练机器学习模型时,经常需要缓存模型。ModelCheckpoint是Pytorch Lightning中的一个Callback,它就是用于模型缓存的。它会监视某个指标,每次指标达到最好的时候,它就缓存当前模型。Pytorch Lightning文档 介绍了ModelCheckpoint的详细信息。

我们来看几个有趣的使用示例。

示例1 注意,我们把epoch和val_loss信息也加入了模型名称。

>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val_loss', #我们想要监视的指标 
...     dirpath='my/path/',  #模型缓存目录
...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}' # 模型名称
... )

示例2 这个使用例子非常像示例1,唯一的差别在于指标的名称是由我们自己指定的,而不是由Pytorch Lightning自动生成的 (auto_insert_metric_name=False)。通过这样的方式,我们可以使用类似val/mrr的指标名。从而统一tensorboard和pytorch lightning对指标的不同描述方式。

>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val/loss',
...     dirpath='my/path/',
...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}', # 注意到val/loss变成了val_loss
...     auto_insert_metric_name=False
... )

Pytorch Lightning把ModelCheckpoint当作最后一个CallBack,也就是它总是在最后执行。这一点在我看来很别扭。如果你在训练过程中想获得best_model_score或者best_model_path,它对应的是上一次模型缓存的结果,而并不是最新的模型缓存结果

self.trainer.checkpoint_callback.best_model_score

你可能感兴趣的:(Pytorch Lightning系列 如何使用ModelCheckpoint)