pytorch版本的SRCNN代码一共分为6个.py文件,结构如下:
以上文件不分先后,执行时通过import…或者from…import…语句进行调用。以下解释import部分均省略,个别例外。
readme.md中给出了不同放大倍数下的训练数据,验证数据和测试数据的下载地址。如果下载了直接把对应的路径写好就可以执行了,这里我们使用自己下载的数据通过使用prepare.py来制作训练和验证的h5格式的数据集。
import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import convert_rgb_to_y
#该函数用来创建自己的h5数据,包括俩个函数:对训练数据的处理和验证部分的处理。
def train(args):
h5_file = h5py.File(args.output_path, 'w')
'''
def是python的关键字,用来定义函数。这里通过def定义名为train的函数,函数的参数为args,args这个参数通过外部命令行传入output
的路径,通过h5py.File()方法的w模式--创建文件自己自写,已经存在的文件会被覆盖,文件的路径是通过args.output_path来传入
'''
lr_patches = []
hr_patches = []
'''
创建俩个空列表:lr_patches和hr_patches(通过ctrl左键该变量名查看在其他位置的引用)
'''
for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):
'''
这部分代码的目的就是搜索指定文件夹下的文件并排序,for这一句包含了几个知识点:
1.{}.format():-->格式化输出函数,从args.images_dir路径中格式化输出路径
2.glob.glob():-->返回所有匹配的文件路径列表,将1得到的路径中的所有文件返回
3.sorted():-->排序,将2得到的所有文件按照某种顺序返回,,默认是升序
4.for x in *: -->循换输出
'''
hr = pil_image.open(image_path).convert('RGB')
'''
1.***.open():是PIL图像库的函数,用来从image_path中加载图像
2.***.convert():是PIL图像库的函数,用来转换图像的模式
'''
hr_width = (hr.width // args.scale) * args.scale
hr_height = (hr.height // args.scale) * args.scale
hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)#缩放处理
hr = np.array(hr).astype(np.float32)
lr = np.array(lr).astype(np.float32)
hr = convert_rgb_to_y(hr)
lr = convert_rgb_to_y(lr)
'''
" / " 表示浮点数除法,返回浮点结果;" // " 表示整数除法,返回不大于结果的一个最大的整数,也就是向下取整
这里的hr是输入的原图,先进行mod和缩放的预处理,lr是hr在mod之后经过scale的结果,得到的lr再经过缩放处理得到最终要用的lr的图片
resize():缩放操作
np.array():将列表list或元组tuple转换为ndarray数组
astype():转换数组的数据类型
convert_rgb_to_y():将图像从RGB格式转换为Y通道格式的图片
假设原始输入图像为(321,481,3)-->依次为高,宽,通道数
1.先mod,之后hr的图像尺寸为(320,480,3)
2.对hr图像进行双三次上采样放大操作
3.将hr//scale进行双三次上采样放大操作之后×scale得到lr
4.接着进行通道数转换和类型转换
'''
for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride):
'''
图像的shape是宽度、高度和通道数,shape[0]是指图像的高度=320;shape[1]是图像的宽度=480;
shape[2]是指图像的通道数
'''
for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride):
lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size])
hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size])
lr_patches = np.array(lr_patches)
hr_patches = np.array(hr_patches)
#把得到的数据转化为数组类型
h5_file.create_dataset('lr', data=lr_patches)
h5_file.create_dataset('hr', data=hr_patches)
h5_file.close()
def eval(args):
h5_file = h5py.File(args.output_path, 'w')
lr_group = h5_file.create_group('lr')
hr_group = h5_file.create_group('hr')
for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):
hr = pil_image.open(image_path).convert('RGB')
hr_width = (hr.width // args.scale) * args.scale
hr_height = (hr.height // args.scale) * args.scale
hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
hr = np.array(hr).astype(np.float32)
lr = np.array(lr).astype(np.float32)
hr = convert_rgb_to_y(hr)
lr = convert_rgb_to_y(lr)
lr_group.create_dataset(str(i), data=lr)
hr_group.create_dataset(str(i), data=hr)
h5_file.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--images-dir', type=str,default='/home/dushuai/word/SRCNN_pytorch/evaldata')
parser.add_argument('--output-path', type=str,default='/home/dushuai/word/SRCNN_pytorch/evalout/evalout.h5')
parser.add_argument('--patch-size', type=int, default=33)
parser.add_argument('--stride', type=int, default=14)
parser.add_argument('--scale', type=int, default=2)
parser.add_argument('--eval', action='store_true')
args = parser.parse_args()
if not args.eval:
train(args)
else:
eval(args)
'''
最后这个if..else..要注意一下,是和parser传入的最后一个参数有关的,它是用来决定使用哪个函数来生成h5文件,因为有俩个不同的函数train和eval生成对应的h5文件。该参数的具体使用方法如下
'''
在我看来这是个很鸡肋的参数设置,但是存在即合理,我们只需要明白它就ok了。
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--eval', action='store_false')
args = parser.parse_args()
def main():
x = args.eval
print(x)
if __name__ == '__main__':
main()
可以看到我上边的action=‘store_false’,但是边一个是直接在IDE中run的结果是True,而我通过命令行运行得到的结果却是false,这是为什么?
顾名思义,store_flase就是存储一个bool值false,也就是说在该参数在被激活时它会输出store存储的值也就是这里我通过命令行得到的值,而IDE得到的值没有激活该参数,得到的是它的默认值True.
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--eval', action='store_false')
args = parser.parse_args()
def a():
print('a')
def b():
print('b')
def main():
x = args.eval
print(x)
if not args.eval:
print(args.eval)
a()
else:
print(args.eval)
b()
if __name__ == '__main__':
main()
在SRCNN的预处理中可以通过修改action中store的值也可以通过if not args.eval来调整函数运行哪个函数来得到对应的结果。
一共包含俩个类TrainDataset()和EvalDataset(),分别用来加载prepare.py制作的训练和验证俩个数据集的。这部分想自己写,但是发现了一篇不错的博客,传送门在此
这部分更为简单,首先定义了模型类SRCNN,它继承自父类nn.Module。super这句是对继承自父类的属性进行初始化。接下来就是对卷积层的定义和前向传播的定义。
这个utils.py相当于是工具类,定义网络需要使用的各种函数。这个文件一共包括了四个函数和一个类,至于test和train都很简单,很容易看懂,略
参考文献:
1.if name == ‘main’:
2.Python之argparse