超分辨率 EDSR开源项目

在这个项目中,为了实现两个功能:train、test,需要三大部分:model、data、loss,使用训练器trainer将这些部分进行连接训练。下面对src文件夹中的各部分进行简单介绍。

一、.\

main.py


是所有程序的入口。它导入了所有内部的程序,并且接收一些外部指令中的参数,由argparser进行解析,传入各个模块,进行配置。

可以看到,loader/model/loss分别在整个工程中的三个其他文件夹(库的形式)定义。而trainer包含了训练和测试两个功能,在同一目录train.py内定义。

videotester.py
用于一帧一帧超分视频

with torch.set_grad_enabled(True/False): 相当于 with torch.no_grad() 用于测试时取消梯度计算,并且配上了model.eval()

如果超分后视频不保存为avi格式,需要将33、34行进行修改。

cv2.VideoWriter_fourcc(*'XVID')的意思:OpenCV - VideoWriter_fourcc - ZhangZhihuiAAA - 博客园

例如如果想保存为.mp4格式,那就应写为cv2.VideoWriter_fourcc('X','2','6','4') #即X264编码

 option.py


通过argparser来对命令行中的参数进行导入,其实大多数都已经默认好了,但是给用户调整的空间。

argparser以dict的形式返回一个namespace,稍后就可以用访问字典的方式访问里面的变量。

 template.py


因为option.py中使用template.set_template(args)来对args(namespace)进行快捷设置,因此,如果需要在这个project中加入自己的新模型,那么要在这个文件中加入并定义,这样可以省下一些参数在命令行中输入的操作。

dataloader.py


 在PyTorch新版本中已经提供相应功能,因此这个文件已经废弃。

utility.py


1、timer类

就像秒表一样,有开始tic,记录时间toc,保持hold(多次hold就将这些并行的用时数据进行累加),重置reset和释放release||主要在trainer配合使用

2、checkpoint类

  • 第一部分中checkpoint提供了断点续训。配置load参数即可。它将所有的log保存在psnr_log.pt文件中,然后通过读取该文件判断断点位置来继续训练。当然,如何在self.log(这是个tensor数组)写log(每一轮测试psnr值)在checkpoint类中没有写,可能在trainer中会对其进行操作。这个断点续训,好像没有模型pt文件导入的流程,可能需要在命令参数中传入。
  • 第二部分主要是将config、log进行记录(txt文件)。
  • 第三部分主要是一些函数的定义,对多进程保存图像、psnr绘图等功能进行了支持

3、calc_psnr函数,其中有将图像转化为YCbCr的Y通道

gray_coeffs = [65.738, 129.057, 25.064]
#转YCbCr:Y’= 0.257*R'+0.504*G'+0.098*B'+16
#16做差后抵消

4、make_optimizer构建了带 多步长衰减 的优化器。作者用CustomOptimizer来继承Optimizer类,使得“make optimizer and scheduler together”。当然,也可以类似于这么写

optimizer_MultiStepLR = torch.optim.SGD(net.parameters(), lr=0.1)
print(optimizer_MultiStepLR)
a = torch.optim.lr_scheduler.MultiStepLR(optimizer_MultiStepLR,
                    milestones=[200, 300, 320, 340, 400], gamma=0.8)

当然,作者这样更加“面向对象”,生成一个优化器,再将网络参数传入,返回一个已经配置好的优化器。上面这种写法,则是先构建一个优化器,然后再去配置。

trainer.py


提供了训练器,用来支持整个训练过程。 

  • init中,导入了数据、模型、损失函数和优化器函数,如果是断点续训,则导入优化器的参数。
  • 在train训练过程中,比较常规的过程,提供了梯度裁剪
  • prepare封装了数据导入GPU的过程,terminate作为根据epoch判断是否结束训练,在main.py中使用。

 (self.ckp.log是咋操作的?此问题待解决)

 二、.\data\

__init__.py


内置了两个类。

MyConcatDataset是对ConCatDataset类进行的封装,用于连接两个类。默认是DIV2K用于训练,如果在此基础上添加Flickr2K用于训练,相当于对此进行了可拓展的支持。

Data类则是提供了data_train和data_test,在前面的trainer.py中可以直接进行导入。

common.py


提供了4个函数,分别用于:从图像中随意取图像小块用于训练 ,增减图像的通道数至训练要求,数组转为Tensor类型,随机对图像进行增强。这些都在srdata类中被调用。

 srdata.py


对PyTorch的data.Dastaset类进行继承,并基于符合超分训练的范式增加相应的功能,建立了数据的导入过程。后面的div2k.py和benchmark.py为重点关注的对象,他们继承了SRData,div2k主要是在输入参数的时候有1~800/800-810,分别用于训练和验证。benchmark主要是对文件导入的路径进行设置。

三、.\loss\

__init__.py


继承了nn.modules.loss._Loss类,对MSE、L1、VGG loss和GAN loss提供支持。

discriminator.py


提供了discriminator判别器,最后输出为一个数,判断这个数与0接近还是1接近。

adversial.py


提供了GAN、WGAN、RaGAN的功能。 

vgg.py


提供了Perceptual loss的 功能。从pytorch官方model zoo中下载vgg19并导入需要的前面几层用于推理。

vgg_features = models.vgg19(pretrained=True).features
modules = [m for m in vgg_features]
if conv_index.find('22') >= 0:
   self.vgg = nn.Sequential(*modules[:8])
elif conv_index.find('54') >= 0:
   self.vgg = nn.Sequential(*modules[:35])

四、 .\model\

__init__.py


Model类,__init__提供了一些基本参数的配置,forward提供了多卡训练,测试时是否裁剪最后拼接,是否进行self-esemble(forward x8)每张图做各个方向共8次,最后进行平均

 forward chop和x8比较混乱,有待进一步整理

common.py对各模块进行封装,方便后续进行调用,而其他文件则是对各网络结构的定义。

求一个小小的赞,祝大家虎年万事如意

欢迎在评论区一起交流讨论(*^▽^*)

你可能感兴趣的:(和火炬PY,超分,pytorch)