pytorch-lightning: ModelCheckpoint does not save the model

My current ModelCheckpoint setup is as follows:

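    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(dirpath=f'runs_fine_tuning/{args.expname}/ckpts/',
                                          filename='{epoch:02d}',
                                          monitor='train/PSNR',
                                          mode='max',
                                          # save_top_k=-1,
                                          save_top_k=2,
                                          # every_n_train_steps=100,
                                          # every_n_val_epochs=1,
                                          save_last=True)
    # Current Lightning docs register the callback via Trainer(callbacks=[checkpoint_callback])
    trainer = Trainer(checkpoint_callback=checkpoint_callback,
                      val_check_interval=500)  # run validation every 500 training steps; very convenient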
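
For reference, `save_top_k=2` with `monitor='train/PSNR'` and `mode='max'` is meant to keep only the two checkpoints with the highest logged `train/PSNR` and delete the rest (`save_top_k=-1` keeps all of them). The bookkeeping involved can be sketched in plain Python. `TopKCheckpoints` is a hypothetical helper written for illustration, not Lightning's implementation, and the PSNR values in `demo()` are made up:

```python
import heapq
import os
import tempfile


class TopKCheckpoints:
    """Hypothetical helper: keep only the k best checkpoint files by a metric.

    Illustrates the intent of save_top_k=k with mode='max'/'min';
    it is not Lightning's implementation.
    """

    def __init__(self, k: int, mode: str = 'max'):
        if mode not in ('max', 'min'):
            raise ValueError("mode must be 'max' or 'min'")
        self.k = k
        # Store sign * score so that "worse" is always "smaller" in the min-heap.
        self.sign = 1.0 if mode == 'max' else -1.0
        self._heap = []  # (signed_score, path); heap[0] is the worst kept file

    def should_save(self, score: float) -> bool:
        """True if a checkpoint with this score would enter the top k."""
        if self.k <= 0:
            return False
        if len(self._heap) < self.k:
            return True
        return self.sign * score > self._heap[0][0]

    def register(self, score: float, path: str) -> list:
        """Record a saved file; delete and return any files pushed out of the top k."""
        heapq.heappush(self._heap, (self.sign * score, path))
        removed = []
        while len(self._heap) > self.k:
            _, worst_path = heapq.heappop(self._heap)
            removed.append(worst_path)
            if os.path.exists(worst_path):
                os.remove(worst_path)
        return removed

    def best_paths(self) -> list:
        """Paths currently kept, best first."""
        return [path for _, path in sorted(self._heap, reverse=True)]


def demo() -> tuple:
    """Simulate validation every 500 steps with made-up PSNR values."""
    fake_psnr = {500: 20.0, 1000: 22.5, 1500: 21.0, 2000: 23.0}
    tracker = TopKCheckpoints(k=2, mode='max')
    with tempfile.TemporaryDirectory() as save_dir:
        for step, psnr in fake_psnr.items():
            if tracker.should_save(psnr):
                path = os.path.join(save_dir, f'{step}.tar')
                open(path, 'w').close()  # stand-in for torch.save(ckpt, path)
                tracker.register(psnr, path)
        kept = [os.path.basename(p) for p in tracker.best_paths()]
        on_disk = sorted(os.listdir(save_dir))
    return kept, on_disk
```

Here `demo()` ends with `2000.tar` and `1000.tar` kept, while `500.tar` and `1500.tar` were deleted once better scores arrived. The same helper could wrap a manual `torch.save` call so that hand-written checkpoints do not pile up every 500 steps.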

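Strangely, the ModelCheckpoint setup above never saved a model, and none of the commented-out variants (`save_top_k=-1`, `every_n_train_steps=100`, `every_n_val_epochs=1`) made it save either. In the end I had to save checkpoints manually, which turned out to be simple enough (the code below is from MVSNeRF):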

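    # Methods of the MVSNeRF LightningModule; they need `import os` and `import torch` at module level.
    def validation_epoch_end(self, outputs):
        mean_psnr_all = torch.stack([x['val_psnr_all'] for x in outputs]).mean()
        self.log('val/PSNR_all', mean_psnr_all, prog_bar=True)
        # Save a checkpoint after every validation run; with val_check_interval=500
        # that means one checkpoint every 500 training steps, named by global_step.
        self.save_ckpt(f'{self.global_step}')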
        return


    def save_ckpt(self, name='latest'):
        save_dir = f'runs_fine_tuning/{self.args.expname}/ckpts/'
        os.makedirs(save_dir, exist_ok=True)
        path = os.path.join(save_dir, f'{name}.tar')  # avoids the double slash from save_dir's trailing '/'
        ckpt = {
            'global_step': self.global_step,
            'network_fn_state_dict': self.render_kwargs_train['network_fn'].state_dict(),
            'volume': self.volume.state_dict(),
            'network_mvs_state_dict': self.MVSNet.state_dict()}
        if self.render_kwargs_train['network_fine'] is not None:
            ckpt['network_fine_state_dict'] = self.render_kwargs_train['network_fine'].state_dict()
        torch.save(ckpt, path)
        print('Saved checkpoints at', path)
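
The resulting `.tar` files are plain dictionaries, so resuming means loading each `state_dict` back into its module. A minimal sketch, assuming small `nn.Linear` stand-ins for `network_fn`, `volume` and `MVSNet` (the real MVSNeRF modules are not shown here); `build_modules`, `load_ckpt` and `roundtrip_demo` are hypothetical helpers, not part of MVSNeRF:

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_modules(seed: int) -> dict:
    """Hypothetical stand-ins for MVSNeRF's network_fn, volume and MVSNet."""
    torch.manual_seed(seed)
    return {
        'network_fn': nn.Linear(8, 4),
        'volume': nn.Linear(4, 4),
        'MVSNet': nn.Linear(4, 2),
    }


def load_ckpt(modules: dict, path: str, device: str = 'cpu') -> int:
    """Restore weights written with save_ckpt()'s key names; return global_step."""
    # map_location remaps stored tensors, e.g. GPU-saved weights onto the CPU
    ckpt = torch.load(path, map_location=device)
    modules['network_fn'].load_state_dict(ckpt['network_fn_state_dict'])
    modules['volume'].load_state_dict(ckpt['volume'])
    modules['MVSNet'].load_state_dict(ckpt['network_mvs_state_dict'])
    # network_fine is only stored when save_ckpt() found a fine network
    if 'network_fine_state_dict' in ckpt and 'network_fine' in modules:
        modules['network_fine'].load_state_dict(ckpt['network_fine_state_dict'])
    return ckpt['global_step']


def roundtrip_demo() -> tuple:
    """Save with save_ckpt()'s key names, then reload into freshly built modules."""
    src = build_modules(seed=0)
    dst = build_modules(seed=1)  # different initial weights
    with tempfile.TemporaryDirectory() as save_dir:
        path = os.path.join(save_dir, '500.tar')
        torch.save({
            'global_step': 500,
            'network_fn_state_dict': src['network_fn'].state_dict(),
            'volume': src['volume'].state_dict(),
            'network_mvs_state_dict': src['MVSNet'].state_dict(),
        }, path)
        step = load_ckpt(dst, path)
    same = all(
        torch.equal(v, dst[name].state_dict()[k])
        for name in src
        for k, v in src[name].state_dict().items()
    )
    return step, same
```

In the real module one would pass the actual networks instead, e.g. `{'network_fn': self.render_kwargs_train['network_fn'], 'volume': self.volume, 'MVSNet': self.MVSNet}`, adding `'network_fine'` when it exists.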
