EDSR代码阅读

https://github.com/thstkdgus35/EDSR-PyTorch

EDSR源代码提供了一个完整的训练框架,值得借鉴参考。

一、\src

__initial__.py:空文件,如果要作为一个包,必须有initial。

option.py:利用argparse包,对demo里面的脚本命令进行支持,建立arg类里面的参数。

template.py:快速补充arg参数。

dataloader.py:PyTorch提供的。

trainer.py:里面提供了一个Trainer类,提供了 train、test、prepare、terminate四个方法。值得注意的是:train方法下提供了梯度裁剪功能,高于阈值的梯度将直接设置为阈值;prepare方法提供了将tensor设置为半精度的方法。

        https://www.cnblogs.com/icodeworld/p/11882263.html
        def _prepare(tensor):
            if self.args.precision == 'half': tensor = tensor.half()
            return tensor.to(device)

utility.py:timer()类提供关于时间的功能。checkpoint类提供了关于模型的保存、训练日志记录、使用多线程来保存训练保存测试结果图片的功能。make_optim函数提供优化器和学习率到达一定的数值后衰减的功能。

import torch.optim.lr_scheduler as lrs
scheduler_class = lrs.MultiStepLR

VideoTester.py:提供视频超分辨率测试的功能。使用了openCV库

二、\src\data

__initial__.py:MyConcatDataset定义了连接不同的数据集;Class Data提供了train和test的dataloader。

srdata.py: 提供了关于图像数据的导入路径、图像转为二进制文件等功能。glob库用于查找文件;pickle库用于导入、生成二进制文件。

benchmark.py: 继承了srdata类。

common.py: get_patch函数获得由DIV2K中的图取一部分为小图,以及对应高分辨率图像中的那一部分,这两个对应的部分输入网络进行训练;set_channel函数将通道数转为图像训练要求的数值;np2tensor将图片转换为PyTorch的tensor类型;augment提供了将图像翻折、旋转的数据增强方式。

demo.py: demo类继承了dataloader,程序里面并没有用处。

其他文件提供对应数据集的训练/测试。

三、\src\loss

提供了损失函数的组合、损失的绘图、保存进度等功能。

四、\src\model

提供了几个模型的程序。init中还提供了分割图像的功能,使得在运存有限的情况下,可以将图片分割为小块以后,超分后合并输出。

你可能感兴趣的:(EDSR代码阅读)