SRCNN-pytoch代码讲解

pytorch版本的SRCNN代码一共分为6个.py文件,结构如下:

  • datasets.py
  • models.py
  • prepare.py
  • utils.py
  • test.py
  • train.py

  以上文件不分先后,执行时通过import…或者from…import…语句进行调用。以下解释import部分均省略,个别例外。

prepare.py

  readme.md中给出了不同放大倍数下的训练数据,验证数据和测试数据的下载地址。如果下载了直接把对应的路径写好就可以执行了,这里我们使用自己下载的数据通过使用prepare.py来制作训练和验证的h5格式的数据集。

import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import convert_rgb_to_y

#该函数用来创建自己的h5数据,包括俩个函数:对训练数据的处理和验证部分的处理。
def train(args):
    h5_file = h5py.File(args.output_path, 'w')
    '''
    def是python的关键字,用来定义函数。这里通过def定义名为train的函数,函数的参数为args,args这个参数通过外部命令行传入output
    的路径,通过h5py.File()方法的w模式--创建文件自己自写,已经存在的文件会被覆盖,文件的路径是通过args.output_path来传入
    '''
    lr_patches = []
    hr_patches = []
    '''
    创建俩个空列表:lr_patches和hr_patches(通过ctrl左键该变量名查看在其他位置的引用)
    '''
    for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):
        '''
        这部分代码的目的就是搜索指定文件夹下的文件并排序,for这一句包含了几个知识点:
        1.{}.format():-->格式化输出函数,从args.images_dir路径中格式化输出路径
        2.glob.glob():-->返回所有匹配的文件路径列表,将1得到的路径中的所有文件返回
        3.sorted():-->排序,将2得到的所有文件按照某种顺序返回,,默认是升序
        4.for x in *:   -->循换输出
        '''
        hr = pil_image.open(image_path).convert('RGB')
        '''
        1.***.open():是PIL图像库的函数,用来从image_path中加载图像
        2.***.convert():是PIL图像库的函数,用来转换图像的模式
        '''
        hr_width = (hr.width // args.scale) * args.scale
        hr_height = (hr.height // args.scale) * args.scale
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
        lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)#缩放处理
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)
        '''
        " / "  表示浮点数除法,返回浮点结果;" // " 表示整数除法,返回不大于结果的一个最大的整数,也就是向下取整
        这里的hr是输入的原图,先进行mod和缩放的预处理,lr是hr在mod之后经过scale的结果,得到的lr再经过缩放处理得到最终要用的lr的图片
        resize():缩放操作
        np.array():将列表list或元组tuple转换为ndarray数组
        astype():转换数组的数据类型
        convert_rgb_to_y():将图像从RGB格式转换为Y通道格式的图片
        假设原始输入图像为(321,481,3)-->依次为高,宽,通道数
        1.先mod,之后hr的图像尺寸为(320,480,3)
        2.对hr图像进行双三次上采样放大操作
        3.将hr//scale进行双三次上采样放大操作之后×scale得到lr
        4.接着进行通道数转换和类型转换
        '''
        for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride):
            '''
            图像的shape是宽度、高度和通道数,shape[0]是指图像的高度=320;shape[1]是图像的宽度=480;
            shape[2]是指图像的通道数
            '''
            for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride):
                lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size])
                hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size])

    lr_patches = np.array(lr_patches)
    hr_patches = np.array(hr_patches)
    #把得到的数据转化为数组类型
    h5_file.create_dataset('lr', data=lr_patches)
    h5_file.create_dataset('hr', data=hr_patches)

    h5_file.close()


def eval(args):
    h5_file = h5py.File(args.output_path, 'w')

    lr_group = h5_file.create_group('lr')
    hr_group = h5_file.create_group('hr')

    for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):
        hr = pil_image.open(image_path).convert('RGB')
        hr_width = (hr.width // args.scale) * args.scale
        hr_height = (hr.height // args.scale) * args.scale
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
        lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)

        lr_group.create_dataset(str(i), data=lr)
        hr_group.create_dataset(str(i), data=hr)

    h5_file.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--images-dir', type=str,default='/home/dushuai/word/SRCNN_pytorch/evaldata')
    parser.add_argument('--output-path', type=str,default='/home/dushuai/word/SRCNN_pytorch/evalout/evalout.h5')
    parser.add_argument('--patch-size', type=int, default=33)
    parser.add_argument('--stride', type=int, default=14)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--eval', action='store_true')
    args = parser.parse_args()

    if not args.eval:
        train(args)
    else:
        eval(args)
'''
最后这个if..else..要注意一下,是和parser传入的最后一个参数有关的,它是用来决定使用哪个函数来生成h5文件,因为有俩个不同的函数train和eval生成对应的h5文件。该参数的具体使用方法如下
'''

实验:action

  在我看来这是个很鸡肋的参数设置,但是存在即合理,我们只需要明白它就ok了。

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--eval', action='store_false')
args = parser.parse_args()
def main():
    x = args.eval
    print(x)

if __name__ == '__main__':
    main()

  可以看到我上边的action=‘store_false’,但是边一个是直接在IDE中run的结果是True,而我通过命令行运行得到的结果却是false,这是为什么?
SRCNN-pytoch代码讲解_第1张图片
SRCNN-pytoch代码讲解_第2张图片
  顾名思义,store_flase就是存储一个bool值false,也就是说在该参数在被激活时它会输出store存储的值也就是这里我通过命令行得到的值,而IDE得到的值没有激活该参数,得到的是它的默认值True.

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--eval', action='store_false')
args = parser.parse_args()

def a():
    print('a')
def b():
    print('b')
def main():
    x = args.eval
    print(x)
    if not args.eval:
        print(args.eval)
        a()
    else:
        print(args.eval)
        b()
if __name__ == '__main__':
    main()

SRCNN-pytoch代码讲解_第3张图片
  在SRCNN的预处理中可以通过修改action中store的值也可以通过if not args.eval来调整函数运行哪个函数来得到对应的结果。

datasets.py

一共包含俩个类TrainDataset()和EvalDataset(),分别用来加载prepare.py制作的训练和验证俩个数据集的。这部分想自己写,但是发现了一篇不错的博客,传送门在此

models.py

这部分更为简单,首先定义了模型类SRCNN,它继承自父类nn.Module。super这句是对继承自父类的属性进行初始化。接下来就是对卷积层的定义和前向传播的定义。

utils.py

这个utils.py相当于是工具类,定义网络需要使用的各种函数。这个文件一共包括了四个函数和一个类,至于test和train都很简单,很容易看懂,略

参考文献:
1.if name == ‘main’:
2.Python之argparse

你可能感兴趣的:(Python,python,深度学习,pytorch,神经网络)