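These are my reading notes on MWCNN. Paper: https://arxiv.org/pdf/1805.07071.pdf
Code: https://github.com/lpj0/MWCNN (MATLAB only)
Using the reference code, I reproduce the network in the PyTorch framework.
The network architecture is shown in the figure below.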
Residual blocks are incorporated in each level of the encoder and decoder.
In each level, the discrete wavelet transform (DWT) is adopted as the downsampling operator in the encoder, and the inverse wavelet transform (IWT) as the upsampling operator in the decoder.
A 3×3 convolution is deployed to compress and expand the feature maps after DWT and IWT, respectively.
For each level, two residual blocks are further deployed to enhance feature representation and reconstruction.
Image restoration, which aims to recover the latent clean image x from its degraded observation y, is a fundamental and long-standing problem in low-level vision.
For image restoration, a CNN essentially represents a mapping from the degraded observation to the latent clean image.
One representative strategy is to use a fully convolutional network (FCN) by removing the pooling layers. In general, a larger receptive field helps restoration performance because more spatial context is taken into account. However, for an FCN without pooling, the receptive field can be enlarged either by increasing the network depth or by using larger filters, and either way the computational cost invariably increases.
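A quick back-of-the-envelope check of this claim (my own illustration, not from the paper): for n stacked stride-1 convolutions with k×k kernels, the receptive field is 1 + n(k−1), while the multiply-adds per output pixel are n·k²·C² when every layer has C input and output channels.

```python
def receptive_field(num_layers, kernel):
    """Receptive field (pixels per side) of stacked stride-1 k x k convolutions."""
    return 1 + num_layers * (kernel - 1)


def macs_per_pixel(num_layers, kernel, channels):
    """Multiply-adds per output pixel, assuming `channels` in and out at every layer."""
    return num_layers * kernel * kernel * channels * channels


if __name__ == "__main__":
    # Enlarging the field by depth: cost grows linearly with the number of layers.
    print(receptive_field(20, 3), macs_per_pixel(20, 3, 64))
    # Enlarging the field by kernel size: cost grows quadratically in k.
    print(receptive_field(10, 5), macs_per_pixel(10, 5, 64))
```

Under these assumptions, 20 layers of 3×3 and 10 layers of 5×5 both reach a 41-pixel receptive field, at 737,280 and 1,024,000 multiply-adds per pixel for C = 64, so neither route enlarges the field for free.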
Dilated filtering can enlarge the receptive field without sacrificing computational cost. However, it inherently suffers from the gridding effect, where the receptive field only covers a sparse sampling of the input image in a checkerboard pattern.
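The gridding effect can be seen by listing which input offsets can influence one output pixel. The sketch below (my own helper, assuming 3×3 kernels, stride 1 and "same" padding) composes the tap sets of stacked layers; with every dilation equal to 2, only even offsets are ever sampled, so the rows and columns in between never contribute.

```python
from itertools import product


def support_1d(dilations):
    """Input offsets (along one axis) that can reach one output pixel.

    Each layer is a 3-tap kernel with dilation d, so it reads offsets -d, 0 and +d;
    stacking layers adds these offsets together.
    """
    offsets = {0}
    for d in dilations:
        offsets = {o + t * d for o in offsets for t in (-1, 0, 1)}
    return offsets


def support_2d(dilations):
    """2-D support of stacked 3 x 3 kernels.

    A 3 x 3 tap set is the Cartesian product of two 1-D tap sets, so the
    stacked 2-D support is the product of the 1-D supports.
    """
    s = support_1d(dilations)
    return set(product(s, s))


if __name__ == "__main__":
    print(sorted(support_1d([1, 1, 1])))  # [-3, -2, -1, 0, 1, 2, 3]: dense
    print(sorted(support_1d([2, 2, 2])))  # [-6, -4, -2, 0, 2, 4, 6]: only even offsets
    grid = support_2d([2, 2, 2])
    for y in range(-6, 7):
        print("".join("#" if (y, x) in grid else "." for x in range(-6, 7)))
```

The printed 13×13 map marks sampled positions with "#"; the regular holes between them are the sparse sampling described above.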
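In image restoration there is a trade-off between receptive-field size and efficiency: a plain CNN usually enlarges its receptive field at the expense of computational cost. The multi-level wavelet CNN (MWCNN) proposed in this paper aims at a better trade-off between receptive-field size and computational efficiency. With a modified U-Net architecture, the wavelet transform is introduced to reduce the size of the feature maps in the contracting subnetwork, and a convolutional layer then further reduces the number of feature-map channels. In the expanding subnetwork, the inverse wavelet transform is deployed to reconstruct the high-resolution feature maps. MWCNN can also be interpreted as a generalization of dilated filtering and subsampling, and it can be applied to other image restoration tasks as well.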
The aim is to enlarge the receptive field for a better trade-off between performance and efficiency. MWCNN is based on the U-Net architecture, which consists of a contracting subnetwork and an expanding subnetwork.
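In the contracting subnetwork, the discrete wavelet transform (DWT) replaces every pooling operation. Since DWT is invertible, this downsampling scheme preserves all the information. Furthermore, DWT captures both the frequency and the location information of the feature maps, which may help recover detail.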
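To see why DWT can replace pooling without losing information, here is a minimal NumPy sketch of one level of the orthonormal 2-D Haar transform (the wavelet the paper uses). It is my own illustration, not the paper's code or the pytorch_wavelets implementation: each 2×2 block becomes four coefficients, so the four half-resolution subbands together hold exactly as many values as the input, and the inverse recovers it exactly.

```python
import numpy as np


def haar_dwt2(x):
    """One level of the orthonormal 2-D Haar DWT on an (H, W) array with even H and W.

    The names lh/hl/hh follow this sketch's own convention; libraries name and
    order the three detail bands differently.
    """
    a = x[0::2, 0::2]  # top-left pixel of every 2 x 2 block
    b = x[0::2, 1::2]  # top-right
    c = x[1::2, 0::2]  # bottom-left
    d = x[1::2, 1::2]  # bottom-right
    ll = (a + b + c + d) / 2  # local average (low frequency)
    lh = (a - b + c - d) / 2  # left-right difference
    hl = (a + b - c - d) / 2  # top-bottom difference
    hh = (a - b - c + d) / 2  # diagonal difference
    return ll, lh, hl, hh


def haar_idwt2(ll, lh, hl, hh):
    """Exact inverse of haar_dwt2."""
    h, w = ll.shape
    x = np.empty((2 * h, 2 * w), dtype=ll.dtype)
    x[0::2, 0::2] = (ll + lh + hl + hh) / 2
    x[0::2, 1::2] = (ll - lh + hl - hh) / 2
    x[1::2, 0::2] = (ll + lh - hl - hh) / 2
    x[1::2, 1::2] = (ll - lh - hl + hh) / 2
    return x


if __name__ == "__main__":
    img = np.random.default_rng(0).random((8, 8))
    bands = haar_dwt2(img)
    print([b.shape for b in bands])              # four (4, 4) subbands
    print(np.allclose(haar_idwt2(*bands), img))  # True: nothing is lost
    # Orthonormal, so the total energy is preserved as well.
    print(np.isclose(sum((b ** 2).sum() for b in bands), (img ** 2).sum()))
```

Max pooling, by contrast, keeps one value per 2×2 block and cannot be undone.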
In the expanding subnetwork, the inverse wavelet transform (IWT) is utilized for upsampling low-resolution feature maps to high-resolution ones.
To enrich feature representation and reduce computational burden, element-wise summation is adopted for combining the feature maps from the contracting and expanding subnetworks.
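A rough way to see the saving (a sketch with an illustrative width C = 256, not a number from the paper): summation keeps C channels at the merge point, whereas U-Net-style concatenation gives 2C, which doubles the weights of the 3×3 convolution that follows.

```python
def conv_params(in_ch, out_ch, kernel=3, bias=True):
    """Learnable parameters of a k x k convolution (weights plus optional bias)."""
    return kernel * kernel * in_ch * out_ch + (out_ch if bias else 0)


if __name__ == "__main__":
    c = 256  # illustrative width of both the skip feature and the decoder feature
    after_sum = conv_params(c, c)      # element-wise sum keeps c channels
    after_cat = conv_params(2 * c, c)  # concatenation doubles them to 2c
    print(after_sum, after_cat)        # 590080 1179904
```

The multiply-adds per pixel of that convolution (k²·C_in·C_out) scale the same way.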
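The figure above shows the decomposition and reconstruction of a single image by the wavelet packet transform (WPT). In fact, WPT is a special case of an FCN without nonlinear layers. MWCNN builds on WPT by adding convolutional layers between any two levels of DWT, as shown in the figure below.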
After each level of transform, all the subband images are taken as the input to a CNN block, which learns a compact representation that serves as the input to the subsequent level of transform. MWCNN is a generalization of multi-level WPT, and degrades to WPT when each CNN block becomes the identity mapping.
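A minimal NumPy sketch of that degenerate case: the Haar DWT is applied to every channel and the subbands are stacked along the channel axis, so after J levels a single-channel image becomes 4^J subbands. With identity "CNN blocks" the pipeline is just a multi-level WPT and its inverse, which reconstructs the input exactly. The function names and the channel ordering are my own assumptions.

```python
import numpy as np


def dwt_channels(x):
    """Haar DWT on every channel of a (C, H, W) array -> (4C, H/2, W/2).

    Output channel blocks are [LL, LH, HL, HH], each holding C channels
    (an ordering chosen for this sketch only).
    """
    a, b = x[:, 0::2, 0::2], x[:, 0::2, 1::2]
    c, d = x[:, 1::2, 0::2], x[:, 1::2, 1::2]
    return np.concatenate(
        [(a + b + c + d) / 2, (a - b + c - d) / 2, (a + b - c - d) / 2, (a - b - c + d) / 2],
        axis=0,
    )


def idwt_channels(y):
    """Inverse of dwt_channels: (4C, H, W) -> (C, 2H, 2W)."""
    ll, lh, hl, hh = np.split(y, 4, axis=0)
    ch, h, w = ll.shape
    x = np.empty((ch, 2 * h, 2 * w), dtype=y.dtype)
    x[:, 0::2, 0::2] = (ll + lh + hl + hh) / 2
    x[:, 0::2, 1::2] = (ll - lh + hl - hh) / 2
    x[:, 1::2, 0::2] = (ll + lh - hl - hh) / 2
    x[:, 1::2, 1::2] = (ll - lh - hl + hh) / 2
    return x


def wpt(x, levels):
    for _ in range(levels):
        x = dwt_channels(x)  # an MWCNN CNN block would go here; identity => plain WPT
    return x


def iwpt(y, levels):
    for _ in range(levels):
        y = idwt_channels(y)  # likewise, a decoder CNN block would precede each IWT
    return y


if __name__ == "__main__":
    img = np.random.default_rng(0).random((1, 8, 8))
    coeffs = wpt(img, 2)
    print(coeffs.shape)                      # (16, 2, 2): 4**2 subbands at 1/4 resolution
    print(np.allclose(iwpt(coeffs, 2), img))  # True
```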
MWCNN can therefore use subsampling operations safely, without information loss. Moreover, compared with a conventional CNN, the frequency and location characteristics of DWT are also expected to benefit the preservation of detailed texture.
The key to the MWCNN architecture is the design of the CNN block after each level of DWT.
Each CNN block is a 4-layer FCN without pooling and takes all the subband images as input. In contrast, deep convolutional framelets deploy different CNNs to the low-frequency and high-frequency bands. The subband images after DWT are still dependent on one another, which is why MWCNN processes them jointly in one CNN block.
Each layer of the CNN block consists of a 3×3 convolution (Conv), batch normalization (BN), and a rectified linear unit (ReLU). For the last layer of the last CNN block, only Conv (without BN and ReLU) is used, to predict the residual image.
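Differences between MWCNN and U-Net: MWCNN downsamples with DWT instead of pooling, upsamples with IWT instead of up-convolution, and combines the contracting and expanding feature maps by element-wise summation instead of concatenation.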
Implementing DWT in PyTorch:
https://github.com/t-vi/pytorch-tvmisc/blob/master/misc/2D-Wavelet-Transform.ipynb
https://github.com/fbcotter/pytorch_wavelets
https://pytorch-wavelets.readthedocs.io/en/latest/dwt.html
cd /home/guanwp/BasicSR-master/codes/models/modules/
git clone https://github.com/fbcotter/pytorch_wavelets
cd pytorch_wavelets
pip install .
Alternatively, use PyWavelets directly:
https://blog.csdn.net/nanbei2463776506/article/details/64124841
The paper uses the Haar wavelet.
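In pytorch_wavelets, DWTForward returns a low-pass tensor yl of shape (N, C, H, W) and a list yh with one (N, C, 3, H, W) tensor per level. To feed them to a convolution they have to be stacked along the channel axis, which is what _transformer in the network code does. The NumPy sketch below (NumPy stands in for torch tensors; pack and unpack are hypothetical names) shows that packing and its exact inverse. The point to note is that the packed channel index is band*C + c, so unpacking must reshape to (N, 4, C, H, W), not (N, C, 4, H, W).

```python
import numpy as np


def pack(yl, yh0):
    """Stack one DWT level along channels as [band0 | band1 | band2 | LL], C channels each.

    yl:  (N, C, H, W)    low-pass coefficients
    yh0: (N, C, 3, H, W) the three high-pass bands of that level
    """
    return np.concatenate([yh0[:, :, 0], yh0[:, :, 1], yh0[:, :, 2], yl], axis=1)


def unpack(packed):
    """Exact inverse of pack: (N, 4C, H, W) -> yl (N, C, H, W), yh0 (N, C, 3, H, W)."""
    n, c4, h, w = packed.shape
    y = packed.reshape(n, 4, c4 // 4, h, w)  # axis 1 indexes the band block
    yl = y[:, 3]
    yh0 = y[:, :3].transpose(0, 2, 1, 3, 4)  # move the band axis behind C
    return yl, yh0


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    yl = rng.random((1, 2, 4, 4))
    yh0 = rng.random((1, 2, 3, 4, 4))
    packed = pack(yl, yh0)
    yl_back, yh_back = unpack(packed)
    print(packed.shape)                                               # (1, 8, 4, 4)
    print(np.array_equal(yl_back, yl), np.array_equal(yh_back, yh0))  # True True
    # Reading the channels as (N, C, 4, H, W) instead picks the wrong coefficients.
    wrong_yl = packed.reshape(1, 2, 4, 4, 4)[:, :, 0]
    print(np.array_equal(wrong_yl, yl))                               # False
```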
Dataset download links:
Waterloo Exploration Database (WED) https://ece.uwaterloo.ca/~k29ma/exploration/
Berkeley Segmentation Dataset (BSD) 200 https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU
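DIV2K800 is the dataset I have been using throughout my previous posts. The following MATLAB script generates the modcropped, bicubic-downsampled and bicubic-upsampled images: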
function generate_mod_LR_bic()
%% MATLAB code to generate mod images, bicubic-downsampled LR images and bicubic-upsampled images.
%% set parameters
% comment out the lines you do not need
input_folder = '/home/guanwp/BasicSR_datasets/DIV2K800_sub';
% save_mod_folder = '';
%save_LR_folder = '/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicLRx4';
save_bic_folder = '/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicubic_X4';
up_scale = 4;
mod_scale = 4;
if exist('save_mod_folder', 'var')
    if exist(save_mod_folder, 'dir')
        disp(['It will cover ', save_mod_folder]);
    else
        mkdir(save_mod_folder);
    end
end
if exist('save_LR_folder', 'var')
    if exist(save_LR_folder, 'dir')
        disp(['It will cover ', save_LR_folder]);
    else
        mkdir(save_LR_folder);
    end
end
if exist('save_bic_folder', 'var')
    if exist(save_bic_folder, 'dir')
        disp(['It will cover ', save_bic_folder]);
    else
        mkdir(save_bic_folder);
    end
end
idx = 0;
filepaths = dir(fullfile(input_folder, '*.*'));
for i = 1 : length(filepaths)
    [paths, imname, ext] = fileparts(filepaths(i).name);
    if isempty(imname)
        disp('Ignore . folder.');
    elseif strcmp(imname, '.')
        disp('Ignore .. folder.');
    else
        idx = idx + 1;
        str_rlt = sprintf('%d\t%s.\n', idx, imname);
        fprintf(str_rlt);
        % read image
        img = imread(fullfile(input_folder, [imname, ext]));
        img = im2double(img);
        % modcrop: crop height and width down to multiples of mod_scale
        img = modcrop(img, mod_scale);
        if exist('save_mod_folder', 'var')
            imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
        end
        % LR: bicubic downsampling by up_scale
        im_LR = imresize(img, 1/up_scale, 'bicubic');
        if exist('save_LR_folder', 'var')
            imwrite(im_LR, fullfile(save_LR_folder, [imname, '_bicLRx4.png']));
        end
        % Bicubic: upsample the LR image back to the HR size
        if exist('save_bic_folder', 'var')
            im_B = imresize(im_LR, up_scale, 'bicubic');
            imwrite(im_B, fullfile(save_bic_folder, [imname, '_bicx4.png']));
        end
    end
end
end

%% modcrop
function img = modcrop(img, modulo)
if size(img,3) == 1
    sz = size(img);
    sz = sz - mod(sz, modulo);
    img = img(1:sz(1), 1:sz(2));
else
    tmpsz = size(img);
    sz = tmpsz(1:2);
    sz = sz - mod(sz, modulo);
    img = img(1:sz(1), 1:sz(2), :);
end
end
The images are then cropped into sub-images of size 512 with stride 512/2 = 256.
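The crop positions along each image axis can be computed as below. This is my own sketch of sliding-window cropping with crop 512 and step 256, adding one border-aligned crop so the last pixels are covered; BasicSR's sub-image extraction script may handle the borders differently.

```python
def crop_starts(length, crop=512, step=256):
    """Top-left offsets of sliding-window crops along one image axis.

    Regular steps of `step` pixels, plus one extra crop flush with the far
    border if the regular grid would leave the last pixels uncovered.
    Images shorter than `crop` yield no crops.
    """
    if length < crop:
        return []
    starts = list(range(0, length - crop + 1, step))
    if starts[-1] + crop < length:
        starts.append(length - crop)
    return starts


if __name__ == "__main__":
    rows, cols = crop_starts(1356), crop_starts(2040)  # hypothetical 2040 x 1356 image
    print(rows)  # [0, 256, 512, 768, 844]
    print(cols)  # [0, 256, 512, 768, 1024, 1280, 1528]
    print(len(rows) * len(cols), "sub-images of 512 x 512")
```

For that hypothetical image this yields 5 × 7 = 35 overlapping sub-images; the 50% overlap gives more training patches per image at the cost of disk space.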
Training options for the MWCNN_data set:
{
  "name": "MWCNN_DATA", //"001_RRDB_PSNR_x4_DIV2K" // please remove "debug_" during training or tensorboard would not work
  "use_tb_logger": true,
  "model": "sr",
  //"crop_scale": 0,
  "scale": 1, //it must be 1
  "gpu_ids": [4,5],
  "datasets": {
    "train": {
      "name": "MWCNN_DATA",
      "mode": "LRHR", //it must be this, and the detail would be shown in LRHR_dataset.py
      //"noise_get": true,
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/MWCNN_data_sub", ///must be sub
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/MWCNN_data_sub_bicubic_X4",
      "subset_file": null,
      "use_shuffle": true,
      "n_workers": 8,
      "batch_size": 24, //16//32 //how many samples in each iters
      "HR_size": 128, // 128 | 192
      "use_flip": false, //true//
      "use_rot": false //true
    },
    "val": {
      "name": "Set5",
      "mode": "LRHR",
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5",
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_bicubic_X4"
      //, "noise_get": true///this is important
    }
  },
  "path": {
    "root": "/home/guanwp/BasicSR-master/",
    "pretrain_model_G": null,
    "experiments_root": "/home/guanwp/BasicSR-master/experiments/",
    "models": "/home/guanwp/BasicSR-master/experiments/MWCNN_DATA/models",
    "log": "/home/guanwp/BasicSR-master/experiments/MWCNN_DATA",
    "val_images": "/home/guanwp/BasicSR-master/experiments/MWCNN_DATA/val_images"
  },
  "network_G": {
    "which_model_G": "mwcnn", //"noise_estimation" //"espcn"//"srresnet"//"sr_resnet"//"fsrcnn"//"sr_resnet" // RRDB_net | sr_resnet
    "norm_type": null,
    "mode": "CNA",
    "nf": 64, //56//64
    "nb": 16, //number of residual block
    "in_nc": 3,
    "out_nc": 3,
    "gc": 32,
    "group": 1
  },
  "train": {
    "lr_G": 1e-3, //8e-4 //1e-3//2e-4
    "lr_scheme": "MultiStepLR",
    "lr_steps": [300000,400000,600000,800000,1000000],
    "lr_gamma": 0.5,
    "pixel_criterion": "l2", //"l2_tv"//"l1"//'l2'//huber//Cross //should be MSE LOSS
    "pixel_weight": 1.0,
    "val_freq": 1e3,
    "manual_seed": 0,
    "niter": 1.2e6 //2e6//1e6
  },
  "logger": {
    "print_freq": 200,
    "save_checkpoint_freq": 1e3
  }
}
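With lr_G = 1e-3, the lr_steps above and lr_gamma = 0.5, the learning rate is halved at each milestone. The sketch below computes the resulting schedule, assuming BasicSR's "MultiStepLR" scheme wraps PyTorch's torch.optim.lr_scheduler.MultiStepLR with lr_steps as milestones and is stepped once per iteration (an assumption about the training loop, not something I checked in this post).

```python
from bisect import bisect_right

LR_STEPS = (300000, 400000, 600000, 800000, 1000000)  # "lr_steps" from the config


def multistep_lr(iteration, base_lr=1e-3, steps=LR_STEPS, gamma=0.5):
    """Learning rate after `iteration` updates: base_lr * gamma ** (milestones passed)."""
    return base_lr * gamma ** bisect_right(steps, iteration)


if __name__ == "__main__":
    for it in (0, 300000, 450000, 1000000, 1200000):
        print(it, multistep_lr(it))
```

Under this assumption, by the end of training (niter = 1.2e6) the learning rate is 1e-3 × 0.5^5 = 3.125e-5.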
Training options for DIV2K800:
{
  "name": "MWCNN_DIVIK", //"001_RRDB_PSNR_x4_DIV2K" // please remove "debug_" during training or tensorboard would not work
  "use_tb_logger": true,
  "model": "sr",
  //"crop_scale": 0,
  "scale": 1, //it must be 1
  "gpu_ids": [4,5],
  "datasets": {
    "train": {
      "name": "DIV2K80",
      "mode": "LRHR", //it must be this, and the detail would be shown in LRHR_dataset.py
      //"noise_get": true,
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub", ///must be sub
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/DIV2K800_sub_bicubic_X4",
      "subset_file": null,
      "use_shuffle": true,
      "n_workers": 8,
      "batch_size": 16, //32 //how many samples in each iters
      "HR_size": 128, // 128 | 192
      "use_flip": false, //true//
      "use_rot": false //true
    },
    "val": {
      "name": "Set5",
      "mode": "LRHR",
      "dataroot_HR": "/home/guanwp/BasicSR_datasets/val_set5/Set5",
      "dataroot_LR": "/home/guanwp/BasicSR_datasets/val_set5/Set5_bicubic_X4"
      //, "noise_get": true///this is important
    }
  },
  "path": {
    "root": "/home/guanwp/BasicSR-master/",
    "pretrain_model_G": null,
    "experiments_root": "/home/guanwp/BasicSR-master/experiments/",
    "models": "/home/guanwp/BasicSR-master/experiments/MWCNN_DIVIK/models",
    "log": "/home/guanwp/BasicSR-master/experiments/MWCNN_DIVIK",
    "val_images": "/home/guanwp/BasicSR-master/experiments/MWCNN_DIVIK/val_images"
  },
  "network_G": {
    "which_model_G": "mwcnn", //"noise_estimation" //"espcn"//"srresnet"//"sr_resnet"//"fsrcnn"//"sr_resnet" // RRDB_net | sr_resnet
    "norm_type": null,
    "mode": "CNA",
    "nf": 64, //56//64
    "nb": 16, //number of residual block
    "in_nc": 3,
    "out_nc": 3,
    "gc": 32,
    "group": 1
  },
  "train": {
    "lr_G": 1e-3, //8e-4 //1e-3//2e-4
    "lr_scheme": "MultiStepLR",
    "lr_steps": [300000,400000,600000,800000,1000000],
    "lr_gamma": 0.5,
    "pixel_criterion": "l2", //"l2_tv"//"l1"//'l2'//huber//Cross //should be MSE LOSS
    "pixel_weight": 1.0,
    "val_freq": 1e3,
    "manual_seed": 0,
    "niter": 1.2e6 //2e6//1e6
  },
  "logger": {
    "print_freq": 200,
    "save_checkpoint_freq": 1e3
  }
}
Add the following branch to the generator factory in networks.py:
#############################################################################################################
elif which_model == 'mwcnn':  # MWCNN
    netG = arch.MWCNN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                      nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
                      act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')
#############################################################################################################
The network itself is defined as follows:
import torch
import torch.nn as nn
from pytorch_wavelets import DWTForward, DWTInverse  # installed from the pytorch_wavelets repo above


class Block_of_DMT1(nn.Module):
    """One Conv(3x3)-BN-ReLU layer on the 160-channel level-1 features."""
    def __init__(self):
        super(Block_of_DMT1, self).__init__()
        self.conv1_1 = nn.Conv2d(in_channels=160, out_channels=160, kernel_size=3, stride=1, padding=1)
        self.bn1_1 = nn.BatchNorm2d(160, affine=True)
        self.relu1_1 = nn.ReLU()

    def forward(self, x):
        return self.relu1_1(self.bn1_1(self.conv1_1(x)))


class Block_of_DMT2(nn.Module):
    """One Conv(3x3)-BN-ReLU layer on the 256-channel level-2 features."""
    def __init__(self):
        super(Block_of_DMT2, self).__init__()
        self.conv2_1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.bn2_1 = nn.BatchNorm2d(256, affine=True)
        self.relu2_1 = nn.ReLU()

    def forward(self, x):
        return self.relu2_1(self.bn2_1(self.conv2_1(x)))


class Block_of_DMT3(nn.Module):
    """One Conv(3x3)-BN-ReLU layer on the 256-channel level-3 features."""
    def __init__(self):
        super(Block_of_DMT3, self).__init__()
        self.conv3_1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.bn3_1 = nn.BatchNorm2d(256, affine=True)
        self.relu3_1 = nn.ReLU()

    def forward(self, x):
        return self.relu3_1(self.bn3_1(self.conv3_1(x)))


# MWCNN
class MWCNN(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, upscale=2, norm_type='batch', act_type='relu',
                 mode='NAC', res_scale=1, upsample_mode='upconv'):  # pay attention to upscale
        # The arguments mirror the other BasicSR generators; the channel widths below are fixed.
        super(MWCNN, self).__init__()
        # Registered as submodules, so model.cuda() / .to(device) also moves the wavelet filters.
        self.DWT = DWTForward(J=1, wave='haar')
        self.IDWT = DWTInverse(wave='haar')
        # DMT1: 3 channels -> DWT -> 3*4 = 12 channels -> conv to 160
        self.conv_DMT1 = nn.Conv2d(in_channels=3 * 4, out_channels=160, kernel_size=3, stride=1, padding=1)
        self.bn_DMT1 = nn.BatchNorm2d(160, affine=True)
        self.relu_DMT1 = nn.ReLU()
        # IDMT1: last layer is a plain conv (no BN/ReLU) back to 12 = 4 subbands x 3 channels
        self.conv_IDMT1 = nn.Conv2d(in_channels=160, out_channels=3 * 4, kernel_size=3, stride=1, padding=1)
        self.blockDMT1 = self.make_layer(Block_of_DMT1, 3)
        self.blockIDMT1 = self.make_layer(Block_of_DMT1, 3)  # decoder block with its own weights
        # DMT2: 160 -> DWT -> 640 -> conv to 256
        self.conv_DMT2 = nn.Conv2d(in_channels=640, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.bn_DMT2 = nn.BatchNorm2d(256, affine=True)
        self.relu_DMT2 = nn.ReLU()
        # IDMT2: 256 -> conv to 640 = 4 subbands x 160 channels
        self.conv_IDMT2 = nn.Conv2d(in_channels=256, out_channels=640, kernel_size=3, stride=1, padding=1)
        self.bn_IDMT2 = nn.BatchNorm2d(640, affine=True)
        self.relu_IDMT2 = nn.ReLU()
        self.blockDMT2 = self.make_layer(Block_of_DMT2, 3)
        self.blockIDMT2 = self.make_layer(Block_of_DMT2, 3)  # decoder block with its own weights
        # DMT3: 256 -> DWT -> 1024 -> conv to 256
        self.conv_DMT3 = nn.Conv2d(in_channels=1024, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.bn_DMT3 = nn.BatchNorm2d(256, affine=True)
        self.relu_DMT3 = nn.ReLU()
        # IDMT3: 256 -> conv to 1024 = 4 subbands x 256 channels
        self.conv_IDMT3 = nn.Conv2d(in_channels=256, out_channels=1024, kernel_size=3, stride=1, padding=1)
        self.bn_IDMT3 = nn.BatchNorm2d(1024, affine=True)
        self.relu_IDMT3 = nn.ReLU()
        self.blockDMT3 = self.make_layer(Block_of_DMT3, 3)
        self.blockDMT4 = self.make_layer(Block_of_DMT3, 3)  # bottom ("DMT4") block with its own weights

    def make_layer(self, block, num_of_layer):
        return nn.Sequential(*[block() for _ in range(num_of_layer)])

    def _transformer(self, DMT1_yl, DMT1_yh):
        # Stack one DWT level along channels as [band0 | band1 | band2 | LL], C channels each.
        # DMT1_yh[0] has shape (N, C, 3, H, W); DMT1_yl has shape (N, C, H, W).
        list_tensor = [DMT1_yh[0][:, :, i, :, :] for i in range(3)]
        list_tensor.append(DMT1_yl)
        return torch.cat(list_tensor, 1)

    def _Itransformer(self, out):
        # Exact inverse of _transformer: packed channel index = band * C + c.
        N, C4, H, W = out.shape
        C = C4 // 4  # integer division; "/" gives a float, which reshape rejects
        y = out.reshape(N, 4, C, H, W)
        yl = y[:, 3].contiguous()                     # (N, C, H, W)
        yh = [y[:, :3].transpose(1, 2).contiguous()]  # (N, C, 3, H, W), as DWTInverse expects
        return yl, yh

    def forward(self, x):
        DMT1_p = x
        # DMT1
        DMT1_yl, DMT1_yh = self.DWT(x)
        DMT1 = self._transformer(DMT1_yl, DMT1_yh)
        out = self.relu_DMT1(self.bn_DMT1(self.conv_DMT1(DMT1)))
        out = self.blockDMT1(out)  # 160 channels, H/2
        DMT2_p = out
        # DMT2
        DMT2_yl, DMT2_yh = self.DWT(out)
        DMT2 = self._transformer(DMT2_yl, DMT2_yh)
        out = self.relu_DMT2(self.bn_DMT2(self.conv_DMT2(DMT2)))
        out = self.blockDMT2(out)  # 256 channels, H/4
        DMT3_p = out
        # DMT3
        DMT3_yl, DMT3_yh = self.DWT(out)
        DMT3 = self._transformer(DMT3_yl, DMT3_yh)
        out = self.relu_DMT3(self.bn_DMT3(self.conv_DMT3(DMT3)))
        out = self.blockDMT3(out)  # 256 channels, H/8
        # IDMT3
        out = self.blockDMT4(out)
        out = self.relu_IDMT3(self.bn_IDMT3(self.conv_IDMT3(out)))  # 1024 channels
        IDMT3 = self.IDWT(self._Itransformer(out))                  # 256 channels, H/4
        out = IDMT3 + DMT3_p
        # IDMT2
        out = self.blockIDMT2(out)
        out = self.relu_IDMT2(self.bn_IDMT2(self.conv_IDMT2(out)))  # 640 channels
        IDMT2 = self.IDWT(self._Itransformer(out))                  # 160 channels, H/2
        out = IDMT2 + DMT2_p
        # IDMT1
        out = self.blockIDMT1(out)
        out = self.conv_IDMT1(out)                                  # 12 channels
        IDMT1 = self.IDWT(self._Itransformer(out))                  # 3 channels, H
        out = IDMT1 + DMT1_p  # residual learning: the network predicts the residual image
        return out
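The hard-coded channel widths above (3·4 = 12 → 160, 160·4 = 640 → 256, 256·4 = 1024 → 256) follow from each DWT quadrupling the channels and halving the resolution. The small pure-Python trace below (my own helper, not part of BasicSR) makes this explicit for the 128×128 HR_size in the config. It also shows why the input height and width should be divisible by 8 for three levels: otherwise the IWT outputs no longer match the skip features they are summed with.

```python
def trace_shapes(in_ch=3, size=128, widths=(160, 256, 256)):
    """Trace (name, channels, height, width) through the three encoder levels above."""
    c, h, w = in_ch, size, size
    trace = [("input", c, h, w)]
    for level, width in enumerate(widths, start=1):
        if h % 2 or w % 2:
            raise ValueError(f"level {level}: {h}x{w} is not divisible by 2")
        c, h, w = 4 * c, h // 2, w // 2  # DWT: 4x channels, half resolution
        trace.append((f"DWT{level}", c, h, w))
        c = width  # conv_DMT{level} compresses the channels
        trace.append((f"conv_DMT{level}", c, h, w))
    return trace


if __name__ == "__main__":
    for row in trace_shapes():
        print(row)
```

For a 128×128 input this ends at 256 channels of 16×16 at the bottom of the network; a 100×100 input already fails at the third level (25 is odd).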