# -*- coding: utf-8 -*-
# =============================================================================
# @article{zhang2017beyond,
# title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
# author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
# journal={IEEE Transactions on Image Processing},
# year={2017},
# volume={26},
# number={7},
# pages={3142-3155},
# }
# by Kai Zhang (08/2018)
# [email protected]
# https://github.com/cszn
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch
# =============================================================================
# run this to test the model
import argparse ## python的参数解析argparse模块
import os, time, datetime # #文件名操作模块glob # time, datetime 时间模块
# import PIL.Image as Image
import numpy as np #导入numpy
import torch.nn as nn #torch.nn的核心数据结构是Module
import torch.nn.init as init #初始化
import torch #包 torch 包含了多维张量的数据结构以及基于其上的多种数学操作。
from skimage.measure import compare_psnr, compare_ssim #计算图像的峰值信噪比(PSNR) #计算两幅图像之间的平均结构相似性指数。
from skimage.io import imread, imsave #在python中,图像处理主要采用的库:skimage, opencv-python, Pillow (PIL)。 这三个库均提供了图像读取的方法。
#创建解析器
def parse_args():
parser = argparse.ArgumentParser()
#增加参数
# 测试集 data/Test
parser.add_argument('--set_dir', default='data/Test', type=str, help='directory of test dataset')
#测试集名字 set68 set12
parser.add_argument('--set_names', default=['Set68', 'Set12'], help='directory of test dataset')
#噪声水平 默认25
parser.add_argument('--sigma', default=25, type=int, help='noise level')
#模型位置 models/DnCNN_sigma25
parser.add_argument('--model_dir', default=os.path.join('models', 'DnCNN_sigma25'), help='directory of the model')
#模型名字 默认model_001.pth
parser.add_argument('--model_name', default='model_001.pth', type=str, help='the model name')
#结果位置 results
parser.add_argument('--result_dir', default='results', type=str, help='directory of test dataset')
#保存结果 1保存 0否
parser.add_argument('--save_result', default=0, type=int, help='save the denoised image, 1 or 0')
return parser.parse_args() #解析参数并返回
##strftime()方法使用日期,时间或日期时间对象返回表示日期和时间的字符串 日志
def log(*args, **kwargs):
print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)
#保存结果
def save_result(result, path):
path = path if path.find('.') != -1 else path+'.png'
ext = os.path.splitext(path)[-1]
if ext in ('.txt', '.dlm'):
np.savetxt(path, result, fmt='%2.4f')
else:
imsave(path, np.clip(result, 0, 1))
#展示图片
def show(x, title=None, cbar=False, figsize=None):
import matplotlib.pyplot as plt
plt.figure(figsize=figsize) # 图像的长和宽(英寸) #Figure返回的实例也将传递给后端的new_figure_manage
plt.imshow(x, interpolation='nearest', cmap='gray') #interpolation 插值方法 #cmap: 颜色图谱(colormap), 默认绘制为RGB(A)颜色空间
if title:
plt.title(title) #标题
if cbar:
plt.colorbar() #将颜色条添加到绘图中。
plt.show()
class DnCNN(nn.Module):
def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
super(DnCNN, self).__init__()
kernel_size = 3
padding = 1
layers = []
layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
layers.append(nn.ReLU(inplace=True))
for _ in range(depth-2):
layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
self.dncnn = nn.Sequential(*layers)
self._initialize_weights()
def forward(self, x):
y = x
out = self.dncnn(x)
return y-out
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.orthogonal_(m.weight)
print('init weight')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
if __name__ == '__main__':
#调用解析器
args = parse_args()
# model = DnCNN()
#如果路径下的模型不存在models/DnCNN_sigma25/model_001.pth
if not os.path.exists(os.path.join(args.model_dir, args.model_name)):
#加载路径下的模型models/DnCNN_sigma25/model.pth
model = torch.load(os.path.join(args.model_dir, 'model.pth'))
# load weights into new model
#加载源代码中自带的训练模型
log('load trained model on Train400 dataset by kai')
else:
# model.load_state_dict(torch.load(os.path.join(args.model_dir, args.model_name)))
#加载自己训练的模型
model = torch.load(os.path.join(args.model_dir, args.model_name))
log('load trained model')
# params = model.state_dict()
# print(params.values())
# print(params.keys())
#
# for key, value in params.items():
# print(key) # parameter name
# print(params['dncnn.12.running_mean'])
# print(model.state_dict())
#测试模型时会在前面加上
model.eval() # evaluation mode
# model.train()
#判断GPU是否可用 可用的话,model放在GPU上测试
if torch.cuda.is_available():
model = model.cuda()
#如果存放结果文件夹不存在,创建文件夹
if not os.path.exists(args.result_dir):
os.mkdir(args.result_dir)
for set_cur in args.set_names:
#如果存放结果的文件夹不存在,创建Test/set12 Test/set68
if not os.path.exists(os.path.join(args.result_dir, set_cur)):
os.mkdir(os.path.join(args.result_dir, set_cur))
psnrs = [] #峰值信噪比
ssims = [] #结构相似性
#args.set_dir:data/Test
#os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表
for im in os.listdir(os.path.join(args.set_dir, set_cur)):
#endswith() 方法用于判断字符串是否以指定后缀结尾,如果以指定后缀结尾返回True,否则返回False。
if im.endswith(".jpg") or im.endswith(".bmp") or im.endswith(".png"):
#np.array将读入的图像转换为ndarray形式,且使像素值处于[0 1]
x = np.array(imread(os.path.join(args.set_dir, set_cur, im)), dtype=np.float32)/255.0
#seed( ) 用于指定随机数生成时所用算法开始的整数值,如果使用相同的seed( )值,则每次生成的随即数都相同,
np.random.seed(seed=0) # 随机生成一个种子 for reproducibility
#numpy.random.normal函数,有三个参数(loc, scale, size),分别l代表生成的高斯分布的随机数的均值、方差以及输出的size.
y = x + np.random.normal(0, args.sigma/255.0, x.shape) # Add Gaussian noise without clipping
#dtype 用于查看数据类型 astype 用于转换数据类型
y = y.astype(np.float32)
#将numpy矩阵y转化为torch张量y_,共享内存
#np.reshape()和torch.view()效果一样,reshape()操作nparray,view()操作tensor
y_ = torch.from_numpy(y).view(1, -1, y.shape[0], y.shape[1])
#正确的测试时间的代码 torch.cuda.synchronize()
torch.cuda.synchronize()
#开始时间
start_time = time.time()
#y_放在GPU
y_ = y_.cuda()
#调用模型
x_ = model(y_) # inference
#把数据二维y.shape[0]* y.shape[1]
x_ = x_.view(y.shape[0], y.shape[1])
#若从gpu –> cpu,则使用data.cpu()
x_ = x_.cpu()
#detach()返回一个新的 从当前图中分离的 Variable,转换为numpy,然后类型转换
x_ = x_.detach().numpy().astype(np.float32)
# 正确的测试时间的代码
torch.cuda.synchronize()
#运行时间
elapsed_time = time.time() - start_time
#输出测试集名字 :图片名 : 运行时间
print('%10s : %10s : %2.4f second' % (set_cur, im, elapsed_time))
psnr_x_ = compare_psnr(x, x_) #skimage.measure库里的compare_psnr方法
ssim_x_ = compare_ssim(x, x_) #skimage.measure库里的compare_ssim方法
#如果存放结果的文件夹存在,那么。。。
if args.save_result:
#分离文件名与扩展名
name, ext = os.path.splitext(im)
#调用自定义函数show()展示图片
#numpy.hstack()函数是将数组沿水平方向堆叠起来
show(np.hstack((y, x_))) # show the image
#调用保存结果函数 参数(X_, 路径results/set12/文件名_dncnn.后缀)
save_result(x_, path=os.path.join(args.result_dir, set_cur, name+'_dncnn'+ext)) # save the denoised image
psnrs.append(psnr_x_) #保存PSNR
ssims.append(ssim_x_) #保存SSIM
#numpy.mean() 函数返回数组中元素的算术平均值
psnr_avg = np.mean(psnrs)
ssim_avg = np.mean(ssims)
#append() 方法用于在列表末尾添加新的对象。
psnrs.append(psnr_avg)
ssims.append(ssim_avg)
#如果文件夹存在
if args.save_result:
#保存(psnr和ssims值,放在文件results.txt里面)
save_result(np.hstack((psnrs, ssims)), path=os.path.join(args.result_dir, set_cur, 'results.txt'))
#打印日志,调用log函数
log('Datset: {0:10s} \n PSNR = {1:2.2f}dB, SSIM = {2:1.4f}'.format(set_cur, psnr_avg, ssim_avg))
argparse --- 命令行选项、参数和子命令解析器
re --- 正则表达式操作 — Python 3.7.4 文档
Python正则表达式指南 - AstralWind - 博客园
Python 标准库——os、glob模块 - 温柔一cai刀 - CSDN博客
python glob模块 - 火星大熊猫 - CSDN博客
Python标准库笔记(3) — datetime模块 - j_hao104 - 博客园
python datetime - 刘江的python教程
python: time模块、datetime模块 - yumu - CSDN博客
NumPy 教程 | 菜鸟教程
火炬 - PyTorch中文文档
torch.nn 神经网络工具 | AI初学者教程
pytorch loss function 总结 - 张小彬的专栏 - CSDN博客
pytorch教程之损失函数详解——多种定义损失函数的方法 - 神评网
torch.nn.init - PyTorch中文文档
PyTorch 中的数据类型 torch.utils.data.DataLoader - rogerfang的博客 - CSDN博客
torch.utils.data - PyTorch中文文档
torch.optim - PyTorch中文文档
pytorch中的学习率调整函数 - 慢行厚积 - 博客园
[pytorch中文文档] torch.optim - pytorch中文网
PyTorch学习之六个学习率调整策略 - mingo_敏 - CSDN博客
torch.cuda.is_available - daoer_sofu的专栏 - CSDN博客
python路径拼接os.path.join()函数完全教程 - 开贰锤 - CSDN博客
os.mkdir()和os.mkdirs()的区别和用法 - 算法小白 - CSDN博客
pytorch(二)--batch normalization的理解 - tanglinjie的CSDN博客 - CSDN博客
PyTorch参数初始化和Finetune - 知乎
PyTorch 实现中的一些常用技巧-PyTorch 中文网
Pytorch.nn.conv2d 过程验证(单,多通道卷积过程) - 知乎
PyTorch 学习笔记(六):PyTorch的十七个损失函数 - spectre - CSDN博客
PyTorch实战指南 - 知乎
【Python】正则表达式 re.findall 用法 - YZXnuaa的博客 - CSDN博客
Python List append()方法 | 菜鸟教程
Python-基础-时间日期处理小结
Python strftime() - datetime to string
Pytorch的net.train 和 net.eval的使用 - Never-Giveup的博客 - CSDN博客
torch.optim.lr_scheduler.MultiStepLR - qq_41872630的博客 - CSDN博客
Python numpy.transpose 详解 - November、Chopin - CSDN博客
Pytorch(五)入门:DataLoader 和 Dataset - 嘿芝麻的树洞 - CSDN博客
Python time time()方法 | 菜鸟教程
Python enumerate() 函数 | 菜鸟教程
torch代码解析 为什么要使用optimizer.zero_grad() - scut_salmon的博客 - CSDN博客
pytorch学习笔记(1)-optimizer.step()和scheduler.step() - 攻城狮的自我修养 - CSDN博客
Pytorch optimizer.step() 和loss.backward()和scheduler.step()的关系与区别 (Pytorch 代码讲解) - xiaoxifei的专栏 - CSDN博客
PyTorch 学习笔记(三):transforms的二十二个方法 - TensorSense的博客 - CSDN博客
python使用numpy读取、保存txt数据 - AManFromEarth的博客 - CSDN博客
Numpy 的一些基础操作必知必会-PyTorch 中文网
measure (measure) - Scikit image 中文开发手册 - 开发者手册 - 云+社区 - 腾讯云
pytorch学习(五)—图像的加载/读取方式 - 简书
matplotlib figure函数学习笔记 - 李啸林的专栏 - CSDN博客
matplotlib.pyplot.figure — Matplotlib 3.1.1 documentation
matplotlib模块数据可视化-图片处理 - sinat_36772813的博客 - CSDN博客
pytorch中图片显示问题 - lighting - CSDN博客
Matplotlib 教程 | 始终
imshow / matshow的插值 - Matplotlib 3.1.1文档
Matplotlib:给子图添加colorbar(颜色条或渐变色条) - 简书
PSNR和SSIM - 文森vincent - 博客园
Python os.listdir() 方法 | 菜鸟教程
Python中startswith和endswith的用法 - Fu4ng - CSDN博客
numpy.random.seed()的使用 - linzch3的博客 - CSDN博客
从np.random.normal()到正态分布的拟合 - Zhang's Wikipedia - CSDN博客
numpy.random.normal函数 - linyi_pk的博客 - CSDN博客
数据格式汇总及type, astype, dtype区别 - 机器学习-深度学习-图像处理-opencv-段子 - CSDN博客
我在读pyTorch文档(二) - aiqiu_gogogo的博客 - CSDN博客
(3条消息)np.reshape()和torch.view() - dspeia的博客 - CSDN博客
Torch张量的view方法有什么作用? - 纯净的天空
pytorch 正确的测试时间的代码 torch.cuda.synchronize() - u013548568的博客 - CSDN博客
PyTorch学习笔记(2)——变量类型(cpu/gpu) - g11d111的博客 - CSDN博客
os.path.splitext(“文件路径”) - 机器学习爱好者 - CSDN博客
numpy数组拼接:stack(),vstack(),hstack()函数使用总结 - Mao_Jonah的博客 - CSDN博客
NumPy 统计函数 | 菜鸟教程
OpenCV Python教程(1、图像的载入、显示和保存) - sunny2038的专栏 - CSDN博客
pytorch实现自由的数据读取-torch.utils.data的学习 - tsq292978891的博客 - CSDN博客
PyTorch—torch.utils.data.DataLoader 数据加载类 - wsp_1138886114的博客 - CSDN博客
Pytorch 04: Pytorch中数据加载---Dataset类和DataLoader类 - 一遍看不懂,我就再看一遍 - CSDN博客
torch.Tensor的4种乘法 - da_kao_la的博客 - CSDN博客
torch.mul() 和 torch.mm() 的区别 - Real_Brilliant的博客 - CSDN博客
PyTorch入门教程(1) - 知乎
pytorch torch张量 - pytorch中文网
PyTorch简明教程 - 李理的博客
x = x.view(x.size(0), -1) 的理解 - whut_ldz的博客 - CSDN博客
Data Augmentation--数据增强解决你有限的数据集 - chang_rj的博客 - CSDN博客
rot90--矩阵旋转 - qq_18343569的博客 - CSDN博客
矩阵的翻转与旋转()(另附代码) - 神评网
Python 中各种imread函数的区别与联系 - Mr. Chen - CSDN博客
opencv imread()方法第二个参数介绍 - qq_27278957的博客 - CSDN博客
位深度、色深的区别以及图片大小的计算 - cc65431362的专栏 - CSDN博客
numpy数据类型 - NumPy 中文文档
5 python numpy.expand_dims的用法 - hxshine的博客 - CSDN博客
numpy中的expand_dims函数 - qm5132的博客 - CSDN博客
(5条消息)numpy的delete删除数组整行和整列 - JamesShawn - CSDN博客
什么是判别模型(Discriminative Model)和生成模型(Generative Model) - zhaoyu106的博客 - CSDN博客
图像去噪算法简介 - InfantSorrow - 博客园
自然图像先验与图像复原 - zbwgycm的博客 - CSDN博客
机器学习概念篇:一文详解凸函数和凸优化,干货满满 - feilong_csdn的博客 - CSDN博客
【图像缩放】双立方(三次)卷积插值 - 程序生涯 - SegmentFault 思否
receptive field,即感受野 - coder - CSDN博客
图像质量评价标准学习笔记(1)-均方误差、峰值信噪比、结构相似性理论、多尺度结构相似性 - weixin_42769131的博客 - CSDN博客
Batch Normalization + Internal Covariate Shift(论文理解) - jason19966的博客 - CSDN博客
torch.nn - PyTorch中文文档
OpenCV图像的基本操作 · OpenCV-Python中文教程 · 看云
[Python] glob 模块(查找文件路径) - 简书
这 5 种计算机视觉技术,刷新你的世界观 | Laravel China 社区
卷积神经网络CNN:Tensorflow实现(以及对卷积特征的可视化) - TinyMind -专注人工智能的技术社区
train_data = train_data.transpose((0,3,1,2))ValueError:轴与数组不匹配·问题#4·ZFTurbo / KAGGLE_DISTRACTED_DRIVER
机器学习 | 王成飞博客
超越图像分类:更多应用深度学习的方法 - 知乎
分类: PyTorch | 从零开始的BLOG