DnCNN代码学习—main_testpy

                                   DnCNN代码学习—main_testpy

一、源代码+注释

# -*- coding: utf-8 -*-

# =============================================================================
#  @article{zhang2017beyond,
#    title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
#    author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
#    journal={IEEE Transactions on Image Processing},
#    year={2017},
#    volume={26}, 
#    number={7}, 
#    pages={3142-3155}, 
#  }
# by Kai Zhang (08/2018)
# [email protected]
# https://github.com/cszn
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch
# =============================================================================

# run this to test the model

import argparse   ## python的参数解析argparse模块
import os, time, datetime    # #文件名操作模块glob     # time, datetime 时间模块
# import PIL.Image as Image
import numpy as np  #导入numpy
import torch.nn as nn  #torch.nn的核心数据结构是Module
import torch.nn.init as init  #初始化
import torch   #包 torch 包含了多维张量的数据结构以及基于其上的多种数学操作。
from skimage.measure import compare_psnr, compare_ssim  #计算图像的峰值信噪比(PSNR)  #计算两幅图像之间的平均结构相似性指数。
from skimage.io import imread, imsave  #在python中,图像处理主要采用的库:skimage, opencv-python, Pillow (PIL)。 这三个库均提供了图像读取的方法。


#创建解析器
def parse_args():
    parser = argparse.ArgumentParser()
    #增加参数
    # 测试集  data/Test
    parser.add_argument('--set_dir', default='data/Test', type=str, help='directory of test dataset')
    #测试集名字  set68 set12
    parser.add_argument('--set_names', default=['Set68', 'Set12'], help='directory of test dataset')
    #噪声水平  默认25
    parser.add_argument('--sigma', default=25, type=int, help='noise level')
    #模型位置  models/DnCNN_sigma25     
    parser.add_argument('--model_dir', default=os.path.join('models', 'DnCNN_sigma25'), help='directory of the model')
    #模型名字 默认model_001.pth
    parser.add_argument('--model_name', default='model_001.pth', type=str, help='the model name')
    #结果位置  results
    parser.add_argument('--result_dir', default='results', type=str, help='directory of test dataset')
    #保存结果 1保存 0否
    parser.add_argument('--save_result', default=0, type=int, help='save the denoised image, 1 or 0')
    return parser.parse_args()  #解析参数并返回

##strftime()方法使用日期,时间或日期时间对象返回表示日期和时间的字符串  日志
def log(*args, **kwargs):
     print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)

#保存结果
def save_result(result, path):
    path = path if path.find('.') != -1 else path+'.png'
    ext = os.path.splitext(path)[-1]
    if ext in ('.txt', '.dlm'):
        np.savetxt(path, result, fmt='%2.4f')
    else:
        imsave(path, np.clip(result, 0, 1))

#展示图片
def show(x, title=None, cbar=False, figsize=None):
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)   #    图像的长和宽(英寸) #Figure返回的实例也将传递给后端的new_figure_manage
    plt.imshow(x, interpolation='nearest', cmap='gray')  #interpolation 插值方法  #cmap: 颜色图谱(colormap), 默认绘制为RGB(A)颜色空间
    if title:
        plt.title(title)   #标题
    if cbar:
        plt.colorbar()  #将颜色条添加到绘图中。
    plt.show()


class DnCNN(nn.Module):

    def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        layers = []
        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x)
        return y-out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)


if __name__ == '__main__':

    #调用解析器
    args = parse_args()

    # model = DnCNN()
    #如果路径下的模型不存在models/DnCNN_sigma25/model_001.pth
    if not os.path.exists(os.path.join(args.model_dir, args.model_name)):
        #加载路径下的模型models/DnCNN_sigma25/model.pth
        model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        # load weights into new model
        #加载源代码中自带的训练模型
        log('load trained model on Train400 dataset by kai')
    else:
        # model.load_state_dict(torch.load(os.path.join(args.model_dir, args.model_name)))
        #加载自己训练的模型
        model = torch.load(os.path.join(args.model_dir, args.model_name))
        log('load trained model')

#    params = model.state_dict()
#    print(params.values())
#    print(params.keys())
#
#    for key, value in params.items():
#        print(key)    # parameter name
#    print(params['dncnn.12.running_mean'])
#    print(model.state_dict())
    
    
    #测试模型时会在前面加上
    model.eval()  # evaluation mode
#    model.train()
     
    #判断GPU是否可用    可用的话,model放在GPU上测试
    if torch.cuda.is_available():
        model = model.cuda()
     
    #如果存放结果文件夹不存在,创建文件夹
    if not os.path.exists(args.result_dir):
        os.mkdir(args.result_dir)

    for set_cur in args.set_names:

        #如果存放结果的文件夹不存在,创建Test/set12   Test/set68
        if not os.path.exists(os.path.join(args.result_dir, set_cur)):
            os.mkdir(os.path.join(args.result_dir, set_cur))
       
        psnrs = []  #峰值信噪比
        ssims = []  #结构相似性
        #args.set_dir:data/Test
        #os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表
        for im in os.listdir(os.path.join(args.set_dir, set_cur)):
            #endswith() 方法用于判断字符串是否以指定后缀结尾,如果以指定后缀结尾返回True,否则返回False。
            if im.endswith(".jpg") or im.endswith(".bmp") or im.endswith(".png"):
                #np.array将读入的图像转换为ndarray形式,且使像素值处于[0 1]
                x = np.array(imread(os.path.join(args.set_dir, set_cur, im)), dtype=np.float32)/255.0
                #seed( ) 用于指定随机数生成时所用算法开始的整数值,如果使用相同的seed( )值,则每次生成的随即数都相同,
                np.random.seed(seed=0)  #  随机生成一个种子  for reproducibility
                #numpy.random.normal函数,有三个参数(loc, scale, size),分别l代表生成的高斯分布的随机数的均值、方差以及输出的size.
                y = x + np.random.normal(0, args.sigma/255.0, x.shape)  # Add Gaussian noise without clipping
                #dtype 用于查看数据类型   astype 用于转换数据类型
                y = y.astype(np.float32)
                #将numpy矩阵y转化为torch张量y_,共享内存
                #np.reshape()和torch.view()效果一样,reshape()操作nparray,view()操作tensor
                y_ = torch.from_numpy(y).view(1, -1, y.shape[0], y.shape[1])
                #正确的测试时间的代码 torch.cuda.synchronize()
                torch.cuda.synchronize()
                #开始时间
                start_time = time.time()
                #y_放在GPU
                y_ = y_.cuda()
                #调用模型
                x_ = model(y_)  # inference
                #把数据二维y.shape[0]* y.shape[1]
                x_ = x_.view(y.shape[0], y.shape[1])
                #若从gpu –> cpu,则使用data.cpu()
                x_ = x_.cpu()
                #detach()返回一个新的 从当前图中分离的 Variable,转换为numpy,然后类型转换
                x_ = x_.detach().numpy().astype(np.float32)
                #  正确的测试时间的代码 
                torch.cuda.synchronize()
                #运行时间
                elapsed_time = time.time() - start_time
                #输出测试集名字 :图片名  :  运行时间
                print('%10s : %10s : %2.4f second' % (set_cur, im, elapsed_time))

                psnr_x_ = compare_psnr(x, x_)   #skimage.measure库里的compare_psnr方法
                ssim_x_ = compare_ssim(x, x_)   #skimage.measure库里的compare_ssim方法
                
                #如果存放结果的文件夹存在,那么。。。
                if args.save_result:
                    #分离文件名与扩展名
                    name, ext = os.path.splitext(im)
                    #调用自定义函数show()展示图片
                    #numpy.hstack()函数是将数组沿水平方向堆叠起来
                    show(np.hstack((y, x_)))  # show the image
                    #调用保存结果函数  参数(X_,  路径results/set12/文件名_dncnn.后缀)
                    save_result(x_, path=os.path.join(args.result_dir, set_cur, name+'_dncnn'+ext))  # save the denoised image
                psnrs.append(psnr_x_)  #保存PSNR
                ssims.append(ssim_x_)  #保存SSIM
        #numpy.mean() 函数返回数组中元素的算术平均值
        psnr_avg = np.mean(psnrs)
        ssim_avg = np.mean(ssims)
        #append() 方法用于在列表末尾添加新的对象。
        psnrs.append(psnr_avg)
        ssims.append(ssim_avg)
        #如果文件夹存在
        if args.save_result:
            #保存(psnr和ssims值,放在文件results.txt里面)
            save_result(np.hstack((psnrs, ssims)), path=os.path.join(args.result_dir, set_cur, 'results.txt'))
        #打印日志,调用log函数
        log('Datset: {0:10s} \n  PSNR = {1:2.2f}dB, SSIM = {2:1.4f}'.format(set_cur, psnr_avg, ssim_avg))

二、查找资料链接

argparse --- 命令行选项、参数和子命令解析器

re --- 正则表达式操作 — Python 3.7.4 文档

Python正则表达式指南 - AstralWind - 博客园

Python 标准库——os、glob模块 - 温柔一cai刀 - CSDN博客

python glob模块 - 火星大熊猫 - CSDN博客

Python标准库笔记(3) — datetime模块 - j_hao104 - 博客园

python datetime - 刘江的python教程

python: time模块、datetime模块 - yumu - CSDN博客

NumPy 教程 | 菜鸟教程

火炬 - PyTorch中文文档

torch.nn 神经网络工具 | AI初学者教程

pytorch loss function 总结 - 张小彬的专栏 - CSDN博客

pytorch教程之损失函数详解——多种定义损失函数的方法 - 神评网

torch.nn.init - PyTorch中文文档

PyTorch 中的数据类型 torch.utils.data.DataLoader - rogerfang的博客 - CSDN博客

torch.utils.data - PyTorch中文文档

torch.optim - PyTorch中文文档

pytorch中的学习率调整函数 - 慢行厚积 - 博客园

[pytorch中文文档] torch.optim - pytorch中文网

PyTorch学习之六个学习率调整策略 - mingo_敏 - CSDN博客

torch.cuda.is_available - daoer_sofu的专栏 - CSDN博客

python路径拼接os.path.join()函数完全教程 - 开贰锤 - CSDN博客

os.mkdir()和os.mkdirs()的区别和用法 - 算法小白 - CSDN博客

pytorch(二)--batch normalization的理解 - tanglinjie的CSDN博客 - CSDN博客

PyTorch参数初始化和Finetune - 知乎

PyTorch 实现中的一些常用技巧-PyTorch 中文网

Pytorch.nn.conv2d 过程验证(单,多通道卷积过程) - 知乎

PyTorch 学习笔记(六):PyTorch的十七个损失函数 - spectre - CSDN博客

PyTorch实战指南 - 知乎

【Python】正则表达式 re.findall 用法 - YZXnuaa的博客 - CSDN博客

Python List append()方法 | 菜鸟教程

Python-基础-时间日期处理小结

Python strftime() - datetime to string

Pytorch的net.train 和 net.eval的使用 - Never-Giveup的博客 - CSDN博客

torch.optim.lr_scheduler.MultiStepLR - qq_41872630的博客 - CSDN博客

Python numpy.transpose 详解 - November、Chopin - CSDN博客

Pytorch(五)入门:DataLoader 和 Dataset - 嘿芝麻的树洞 - CSDN博客

Python time time()方法 | 菜鸟教程

Python enumerate() 函数 | 菜鸟教程

torch代码解析 为什么要使用optimizer.zero_grad() - scut_salmon的博客 - CSDN博客

pytorch学习笔记(1)-optimizer.step()和scheduler.step() - 攻城狮的自我修养 - CSDN博客

Pytorch optimizer.step() 和loss.backward()和scheduler.step()的关系与区别 (Pytorch 代码讲解) - xiaoxifei的专栏 - CSDN博客

PyTorch 学习笔记(三):transforms的二十二个方法 - TensorSense的博客 - CSDN博客

python使用numpy读取、保存txt数据 - AManFromEarth的博客 - CSDN博客

Numpy 的一些基础操作必知必会-PyTorch 中文网

measure (measure) - Scikit image 中文开发手册 - 开发者手册 - 云+社区 - 腾讯云

pytorch学习(五)—图像的加载/读取方式 - 简书

matplotlib figure函数学习笔记 - 李啸林的专栏 - CSDN博客

matplotlib.pyplot.figure — Matplotlib 3.1.1 documentation

matplotlib模块数据可视化-图片处理 - sinat_36772813的博客 - CSDN博客

pytorch中图片显示问题 - lighting - CSDN博客

Matplotlib 教程 | 始终

imshow / matshow的插值 - Matplotlib 3.1.1文档

Matplotlib:给子图添加colorbar(颜色条或渐变色条) - 简书

PSNR和SSIM - 文森vincent - 博客园

Python os.listdir() 方法 | 菜鸟教程

Python中startswith和endswith的用法 - Fu4ng - CSDN博客

numpy.random.seed()的使用 - linzch3的博客 - CSDN博客

从np.random.normal()到正态分布的拟合 - Zhang's Wikipedia - CSDN博客

numpy.random.normal函数 - linyi_pk的博客 - CSDN博客

数据格式汇总及type, astype, dtype区别 - 机器学习-深度学习-图像处理-opencv-段子 - CSDN博客

我在读pyTorch文档(二) - aiqiu_gogogo的博客 - CSDN博客

(3条消息)np.reshape()和torch.view() - dspeia的博客 - CSDN博客

Torch张量的view方法有什么作用? - 纯净的天空

pytorch 正确的测试时间的代码 torch.cuda.synchronize() - u013548568的博客 - CSDN博客

PyTorch学习笔记(2)——变量类型(cpu/gpu) - g11d111的博客 - CSDN博客

os.path.splitext(“文件路径”) - 机器学习爱好者 - CSDN博客

numpy数组拼接:stack(),vstack(),hstack()函数使用总结 - Mao_Jonah的博客 - CSDN博客

NumPy 统计函数 | 菜鸟教程

OpenCV Python教程(1、图像的载入、显示和保存) - sunny2038的专栏 - CSDN博客

pytorch实现自由的数据读取-torch.utils.data的学习 - tsq292978891的博客 - CSDN博客

PyTorch—torch.utils.data.DataLoader 数据加载类 - wsp_1138886114的博客 - CSDN博客

Pytorch 04: Pytorch中数据加载---Dataset类和DataLoader类 - 一遍看不懂,我就再看一遍 - CSDN博客

torch.Tensor的4种乘法 - da_kao_la的博客 - CSDN博客

torch.mul() 和 torch.mm() 的区别 - Real_Brilliant的博客 - CSDN博客

PyTorch入门教程(1) - 知乎

pytorch torch张量 - pytorch中文网

PyTorch简明教程 - 李理的博客

x = x.view(x.size(0), -1) 的理解 - whut_ldz的博客 - CSDN博客

Data Augmentation--数据增强解决你有限的数据集 - chang_rj的博客 - CSDN博客

rot90--矩阵旋转 - qq_18343569的博客 - CSDN博客

矩阵的翻转与旋转()(另附代码) - 神评网

Python 中各种imread函数的区别与联系 - Mr. Chen - CSDN博客

opencv imread()方法第二个参数介绍 - qq_27278957的博客 - CSDN博客

位深度、色深的区别以及图片大小的计算 - cc65431362的专栏 - CSDN博客

numpy数据类型 - NumPy 中文文档

5 python numpy.expand_dims的用法 - hxshine的博客 - CSDN博客

numpy中的expand_dims函数 - qm5132的博客 - CSDN博客

(5条消息)numpy的delete删除数组整行和整列 - JamesShawn - CSDN博客

什么是判别模型(Discriminative Model)和生成模型(Generative Model) - zhaoyu106的博客 - CSDN博客

图像去噪算法简介 - InfantSorrow - 博客园

自然图像先验与图像复原 - zbwgycm的博客 - CSDN博客

机器学习概念篇:一文详解凸函数和凸优化,干货满满 - feilong_csdn的博客 - CSDN博客

【图像缩放】双立方(三次)卷积插值 - 程序生涯 - SegmentFault 思否

receptive field,即感受野 - coder - CSDN博客

图像质量评价标准学习笔记(1)-均方误差、峰值信噪比、结构相似性理论、多尺度结构相似性 - weixin_42769131的博客 - CSDN博客

Batch Normalization + Internal Covariate Shift(论文理解) - jason19966的博客 - CSDN博客

torch.nn - PyTorch中文文档

OpenCV图像的基本操作 · OpenCV-Python中文教程 · 看云

[Python] glob 模块(查找文件路径) - 简书

这 5 种计算机视觉技术,刷新你的世界观 | Laravel China 社区

卷积神经网络CNN:Tensorflow实现(以及对卷积特征的可视化) - TinyMind -专注人工智能的技术社区

train_data = train_data.transpose((0,3,1,2))ValueError:轴与数组不匹配·问题#4·ZFTurbo / KAGGLE_DISTRACTED_DRIVER

机器学习 | 王成飞博客

超越图像分类:更多应用深度学习的方法 - 知乎

分类: PyTorch | 从零开始的BLOG

你可能感兴趣的:(深度学习)