实验笔记之——参数量为0.1M的超分网络(octave layer)

本博文为实验笔记

实验python train.py -opt options/train/train_sr.json

先激活虚拟环境source activate pytorch

tensorboard --logdir tb_logger/ --port 6008

浏览器打开http://172.20.36.203:6008/#scalars
 

参数量为176,234

参数量为177,194

 

网络

##############################################################################################
#Octave CARN
class Octave_CARN(nn.Module):#nb=3(3 block),channel=24
    def __init__(self, in_nc, out_nc, nf=24, nc=4, nb=3, alpha=0.75, upscale=4, norm_type=None, act_type='prelu', \
            mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(Octave_CARN, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        self.nb = nb

        self.fea_conv =B.conv_block(in_nc, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode='CNA')

        ##################################################################
        #self.oct_first=B.FirstOctaveConv(nf, nf, kernel_size=3,  alpha=alpha, stride=1, dilation=1, groups=1, bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA')
        self.oct_first =B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type='prelu', mode='CNA')
        ##################################################################
        #self.CascadeBlocks = nn.ModuleList([B.OctaveCascadeBlock(nc, nf, kernel_size=3, alpha=alpha, norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)])
        self.CascadeBlocks = nn.ModuleList([B.CascadeBlock(nc, nf, kernel_size=3, norm_type=norm_type, act_type=act_type, mode=mode, res_scale=res_scale) for _ in range(nb)])
        ##################################################################
        #self.CatBlocks = nn.ModuleList([B.OctaveConv((i + 2)*nf, nf, kernel_size=1, alpha=alpha, norm_type=norm_type, act_type=act_type, mode=mode) for i in range(nb)])
        self.CatBlocks = nn.ModuleList([B.conv_block((i + 2)*nf, nf, kernel_size=1, norm_type=norm_type, act_type=act_type, mode=mode) for i in range(nb)])
        ##################################################################
        #self.oct_last = B.LastOctaveConv(nf, nf, kernel_size=3, alpha=alpha, stride=1, dilation=1, groups=1, bias=True, pad_type='zero', norm_type=None, act_type='prelu', mode='CNA')
        self.oct_last = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type='prelu', mode='CNA')
        ##################################################################
        self.upsampler = nn.PixelShuffle(upscale)
        self.HR_conv1 = B.conv_block(nf, in_nc*(upscale ** 2), kernel_size=3, norm_type=None, act_type=None)


    def forward(self, x):
        x = self.fea_conv(x)
        x = self.oct_first(x)
        pre_fea = x
        # for i in range(self.nb):
        #     res = self.CascadeBlocks[i](x)
        #     pre_fea = (torch.cat((pre_fea[0], res[0]), dim=1), \
        #                 torch.cat((pre_fea[1], res[1]), dim=1))
        #     x = self.CatBlocks[i](pre_fea)
        for i in range(self.nb):
            res = self.CascadeBlocks[i](x)
            pre_fea = torch.cat((pre_fea, res), dim=1)
            x = self.CatBlocks[i](pre_fea)
        x = self.oct_last(x)
        x = self.HR_conv1(x)
        x = F.sigmoid(self.upsampler(x))
        return x


# ##############################################################################################

实验结果:

实验笔记之——参数量为0.1M的超分网络(octave layer)_第1张图片

实验说明

1、首先,发现一个现象是:采用DWT代替POOL,performance有所提升。猜测可能pooling会导致一些信息的丢失,而小波变换是可逆变换,特别是alpha为0.75时,performance差异更加明显。

实验笔记之——参数量为0.1M的超分网络(octave layer)_第2张图片

alpha=0时,为32.12dB

2、在CARN(0.1参数量)如下表格所示

实验笔记之——参数量为0.1M的超分网络(octave layer)_第3张图片

alpha=0时,为31.55dB

3、第三个实验比较出乎意料。我用SRResNet的框架(原来是16个block,我现在改为7个)参数量为0.8M左右(原来为1.7M左右),但是却发现加入了octave layer后performance竟然提升了,而且还是0.1dB的提升。那就意味着减少了计算量还提升了performance。

实验笔记之——参数量为0.1M的超分网络(octave layer)_第4张图片

补充实验

实验笔记之——参数量为0.1M的超分网络(octave layer)_第5张图片

实验笔记之——参数量为0.1M的超分网络(octave layer)_第6张图片

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

你可能感兴趣的:(卷积神经网络,超分辨率重建,图像超分辨率重建)