在这个项目中,为了实现两个功能:train、test,需要三大部分:model、data、loss,使用训练器trainer将这些部分进行连接训练。下面对src文件夹中的各部分进行简单介绍。
main.py
是所有程序的入口。它导入了所有内部的程序,并且接收一些外部指令中的参数,由argparser进行解析,传入各个模块,进行配置。
可以看到,loader/model/loss分别在整个工程中的三个其他文件夹(库的形式)定义。而trainer包含了训练和测试两个功能,在同一目录train.py内定义。
videotester.py
用于一帧一帧超分视频with torch.set_grad_enabled(True/False): 相当于 with torch.no_grad() 用于测试时取消梯度计算,并且配上了model.eval()
如果超分后视频不保存为avi格式,需要将33、34行进行修改。
cv2.VideoWriter_fourcc(*'XVID')的意思:OpenCV - VideoWriter_fourcc - ZhangZhihuiAAA - 博客园
例如如果想保存为.mp4格式,那就应写为cv2.VideoWriter_fourcc('X','2','6','4') #即X264编码
option.py
通过argparser来对命令行中的参数进行导入,其实大多数都已经默认好了,但是给用户调整的空间。
argparser以dict的形式返回一个namespace,稍后就可以用访问字典的方式访问里面的变量。
template.py
因为option.py中使用template.set_template(args)来对args(namespace)进行快捷设置,因此,如果需要在这个project中加入自己的新模型,那么要在这个文件中加入并定义,这样可以省下一些参数在命令行中输入的操作。
dataloader.py
在PyTorch新版本中已经提供相应功能,因此这个文件已经废弃。
utility.py
1、timer类
就像秒表一样,有开始tic,记录时间toc,保持hold(多次hold就将这些并行的用时数据进行累加),重置reset和释放release||主要在trainer配合使用
2、checkpoint类
- 第一部分中checkpoint提供了断点续训。配置load参数即可。它将所有的log保存在psnr_log.pt文件中,然后通过读取该文件判断断点位置来继续训练。当然,如何在self.log(这是个tensor数组)写log(每一轮测试psnr值)在checkpoint类中没有写,可能在trainer中会对其进行操作。这个断点续训,好像没有模型pt文件导入的流程,可能需要在命令参数中传入。
- 第二部分主要是将config、log进行记录(txt文件)。
- 第三部分主要是一些函数的定义,对多进程保存图像、psnr绘图等功能进行了支持
3、calc_psnr函数,其中有将图像转化为YCbCr的Y通道
gray_coeffs = [65.738, 129.057, 25.064] #转YCbCr:Y’= 0.257*R'+0.504*G'+0.098*B'+16 #16做差后抵消
4、make_optimizer构建了带 多步长衰减 的优化器。作者用CustomOptimizer来继承Optimizer类,使得“make optimizer and scheduler together”。当然,也可以类似于这么写
optimizer_MultiStepLR = torch.optim.SGD(net.parameters(), lr=0.1) print(optimizer_MultiStepLR) a = torch.optim.lr_scheduler.MultiStepLR(optimizer_MultiStepLR, milestones=[200, 300, 320, 340, 400], gamma=0.8)
当然,作者这样更加“面向对象”,生成一个优化器,再将网络参数传入,返回一个已经配置好的优化器。上面这种写法,则是先构建一个优化器,然后再去配置。
trainer.py
提供了训练器,用来支持整个训练过程。
- init中,导入了数据、模型、损失函数和优化器函数,如果是断点续训,则导入优化器的参数。
- 在train训练过程中,比较常规的过程,提供了梯度裁剪
- prepare封装了数据导入GPU的过程,terminate作为根据epoch判断是否结束训练,在main.py中使用。
(self.ckp.log是咋操作的?此问题待解决)
__init__.py
内置了两个类。
MyConcatDataset是对ConCatDataset类进行的封装,用于连接两个类。默认是DIV2K用于训练,如果在此基础上添加Flickr2K用于训练,相当于对此进行了可拓展的支持。
Data类则是提供了data_train和data_test,在前面的trainer.py中可以直接进行导入。
common.py
提供了4个函数,分别用于:从图像中随意取图像小块用于训练 ,增减图像的通道数至训练要求,数组转为Tensor类型,随机对图像进行增强。这些都在srdata类中被调用。
srdata.py
对PyTorch的data.Dastaset类进行继承,并基于符合超分训练的范式增加相应的功能,建立了数据的导入过程。后面的div2k.py和benchmark.py为重点关注的对象,他们继承了SRData,div2k主要是在输入参数的时候有1~800/800-810,分别用于训练和验证。benchmark主要是对文件导入的路径进行设置。
__init__.py
继承了nn.modules.loss._Loss类,对MSE、L1、VGG loss和GAN loss提供支持。
discriminator.py
提供了discriminator判别器,最后输出为一个数,判断这个数与0接近还是1接近。
adversial.py
提供了GAN、WGAN、RaGAN的功能。
vgg.py
提供了Perceptual loss的 功能。从pytorch官方model zoo中下载vgg19并导入需要的前面几层用于推理。
vgg_features = models.vgg19(pretrained=True).features modules = [m for m in vgg_features] if conv_index.find('22') >= 0: self.vgg = nn.Sequential(*modules[:8]) elif conv_index.find('54') >= 0: self.vgg = nn.Sequential(*modules[:35])
__init__.py
Model类,__init__提供了一些基本参数的配置,forward提供了多卡训练,测试时是否裁剪最后拼接,是否进行self-esemble(forward x8)每张图做各个方向共8次,最后进行平均
forward chop和x8比较混乱,有待进一步整理
common.py对各模块进行封装,方便后续进行调用,而其他文件则是对各网络结构的定义。
求一个小小的赞,祝大家虎年万事如意
欢迎在评论区一起交流讨论(*^▽^*)