PyTorch Lightning 中的批量测试及其存在的问题

2022-1-5, Wed., 13:37 于鸢尾花基地
可以采用如下方式对之前保存的预训练模型进行批量测试:

for ckpt in ckpt_list:
    model = ptl_module.load_from_checkpoint(ckpt, args=args)
    trainer.test(model, dataloaders=test_dataloader)

然而,在上述循环中,通过trainer.test每执行一次测试,都只是执行了一个epoch的测试(也就是执行多次ptl_module.test_step和一次ptl_module.test_epoch_end),而不可能把ckpt_list中的多个预训练模型(checkpoint)当做多个epoch,多次执行ptl_module.test_epoch_end

我们期望,对多个checkpoint的测试能像对多个epoch的训练一样简洁:

trainer.test(ptl_module, dataloaders=test_dataloader)

怎么做到?在训练过程中,要训练多少个epoch是由参数max_epochs来决定的;而在测试过程中,怎么办?PTL并非完整地保存了所有epoch的预训练模型。

由于在测试过程中对各checkpoint是独立测试的,如果要统计多个checkpoint的最优性能(如最大PSNR/SSIM),怎么办?这里的一个关键问题是如何保存每次测试得到的评估结果,好像PTL并未对此提供接口。

解决方案
PTL提供了“回调类(Callback)”(在 pytorch_lightning.callbacks 中),可以自定义一个回调类,并重载on_test_epoch_end方法,来监听ptl_module.test_epoch_end
如何使用?只需要在定义trainer时,把该自定义的回调函数加入其参数callbacks即可:ptl.Trainer(callbacks=[MetricTracker()])。这里,MetricTracker为自定义的回调类,具体如下:

class MetricTracker(Callback):

    def __init__(self):
        self.optim_metrics = None

    def on_test_epoch_end(self, trainer, pl_module):
        if self.optim_metrics is None:
            self.optim_metrics = pl_module.metrics_dict
            return

        tensorboard = pl_module.logger.experiment
        metrics_key_list, metrics_val_list = [], []
        for k in pl_module.metrics_dict:
            # comp_fun 是自己定义的比较函数
            self.optim_metrics[k] = comp_fun(self.optim_metrics[k], pl_module.metrics_dict[k])

评论: 由于MetricTracker具有与Trainer相同的生命周期,因此,在整个测试过程中,MetricTracker能够维护一个最优的评估结果optim_metrics

你可能感兴趣的:(PyTorch Lightning 中的批量测试及其存在的问题)