DeblurGAN-V2源代码解析

DeblurGAN-V2源代码解析(pytorch)

DeblurGAN-V2源代码解析_第1张图片
DeblurGAN-V2是DeblurGAN的改进版,主要解决的是去图像运动模糊的问题,相比于DeblurGAN而言有速度更快,效果更好的优点。

论文:https://arxiv.org/pdf/1908.03826.pdf
代码:https://github.com/TAMU-VITA/DeblurGANv2
博客讲解:https://blog.csdn.net/weixin_42784951/article/details/100168882

本文主要针对作者源代码进行总结,不足之处尽请提出。

1、全部文件

以下是从github上下载的全部文件,训练运行train.py文件,评价运行predict.py文件
DeblurGAN-V2源代码解析_第2张图片

2、整体结构

1、config文件是参数配置文件,主要设置了模型中所需要的各种参数;
2、models文件是模型文件,主要用于网络结构和网络模型的搭建;
3、util文件是图像处理文件,主要用于图像的基本处理,和SSIM、PSNR的实现
4、生成数据集主要由dataset.py、test_dataset.py用于进行数据集的生成;

3、train.py模型训练主文件

- 首先,看主程序:

if __name__ == '__main__':
	#1、读入配置文件
    with open('config/config.yaml', 'r') as f:
        config = yaml.load(f)

    batch_size = config.pop('batch_size')
    
    #partial(),python中的偏函数
    #2、得到原始数据
    get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=True)
    datasets = map(config.pop, ('train', 'val'))
    datasets = map(PairedDataset.from_config, datasets)
    train, val = map(get_dataloader, datasets)
    #3、实例化Trainer(),并且进行训练
    trainer = Trainer(config, train=train, val=val)
    trainer.train()

1、读入配置文件
2、得到原始数据
3、实例化Trainer(),并且进行训练trainer.train()。

- 其次,train()函数:

def train(self):
        self._init_params()
        for epoch in range(0, config['num_epochs']):
            if (epoch == self.warmup_epochs) and not (self.warmup_epochs == 0):
                self.netG.module.unfreeze()
                self.optimizer_G = self._get_optim(self.netG.parameters())
                self.scheduler_G = self._get_scheduler(self.optimizer_G)
            self._run_epoch(epoch)
            self._validate(epoch)
            self.scheduler_G.step()
            self.scheduler_D.step()

            if self.metric_counter.update_best_model():
                torch.save({
                    'model': self.netG.state_dict()
                }, 'best_{}.h5'.format(self.config['experiment_desc']))
            torch.save({
                'model': self.netG.state_dict()
            }, 'last_{}.h5'.format(self.config['experiment_desc']))
            print(self.metric_counter.loss_message())
            logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
                self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))

(未完待续)

你可能感兴趣的:(GAN生成对抗网络)