论文笔记:CBDNet图像去噪网络

论文笔记:CBDNet图像去噪网络_第1张图片


文章目录

  • 前言
    • 传统深度卷积神经网络去噪方法的局限性
    • 本篇论文的改进方法
    • 其主要贡献在于以下几点:
  • 一、论文介绍
    • 噪声模型
    • 网络结构
    • 损失函数
      • 1.非对称损失 L a s y m m \mathcal{L}_{asymm} Lasymm
      • 2.全变分损失
      • 3.重建损失 L r e c \mathcal{L}_{rec} Lrec
    • 训练
  • 二、代码实现
    • 2.定义 CBDNet 网络
      • 2.1 FCN部分
    • 2.2UNet部分
    • 2.3 CBDNet 分析
    • 2.4损失函数
    • 2.4 训练
  • 结果


前言

    这是哈工大与香港理工大Lei Zhang老师课题组合作完成的论文,这两个团队在图像去噪方面一直走在前沿,许多经典工作都是他们提出的,如WNNM、DnCNN等。这一篇也是其在深度图像去噪方面的新的文章。与其前面的工作不同的是,以前的图像去噪大多使用合成数据,这篇文章研究了CNN在真实图像上的去噪效果,

传统深度卷积神经网络去噪方法的局限性

    1.大多数存在的盲去噪的方法都包括两步:噪声估计和非盲目去噪。
    2.深度卷积神经网络的效果依赖于训练数据,但真实噪声图像和干净图像太少,而合成的噪声图像与真实噪声图像相差太大。
    3.真实噪声的特征不能充分地被设计的噪声模型所刻画。
    4.非盲目去噪器(BM3D、FFDNet)对低估噪声等级敏感,而对高估噪声等级表现良好。即在噪声估计网络对噪声图像的噪声估计的噪声等级低于实际噪声等级时,去噪效果不好,但当噪声估计网络对噪声图像的噪声估计的噪声的呢估计高于实际噪声等级时去噪效果良好。

本篇论文的改进方法

    1.针对第一点,同样分为两个子网络:噪声估计子网络和非盲目去噪子网络
    2.针对第二点和第三点,论文选择同时用合成噪声图像和真实噪声图像交替训练网络 。
    3.针对第三点,论文提出了一个更接近真实噪声的模型,既考虑了信号相关的噪声,又考虑了摄像机的处理流水线中的噪声。
    4.针对第四点,论文充分利用BM3D对高估计噪声等级表现良好的特性,选择用非对称的方法来学习,即的那个噪声估计网络高估噪声时,给与一个较小的惩罚,而的那个网络低估噪声等级时,给予较大的惩罚。

其主要贡献在于以下几点:

  • 提出了一个更加真实的噪声模型,其考虑了信号依赖噪声和ISP流程对噪声的影响,展示了图像噪声模型在真实噪声图像中起着关键作用。
  • 提出了CBDNet模型,其包括了一个噪声估计子网络和一个非盲去噪子网络,可以实现图像的盲去噪(即未知噪声水平)。
  • 提出了非对称学习(asymmetric learning)的损失函数,并允许用户交互式调整去噪结果,增强了去噪结果的鲁棒性。
  • 将合成噪声图像与真实噪声图像一起用于网络的训练,提升网络的去噪效果和泛化能力。

一、论文介绍

    CBDNet这篇文章针对的则是模型在真实噪声上效果差的问题,使得去噪不再局限于较理想化的高斯噪声。传统CNN去噪模型的效果很大程度上取决于合成噪声和实际噪声的分布是否匹配,于是本文的去噪模型分为两阶段——第一阶段进行噪声估计,第二阶段将噪声估计结果与噪声图一并作为输入进行非盲去噪。

噪声模型

对噪声进行建模是为了生成去噪网络的训练集,建模越趋近真实噪声后续去噪效果也便越好。对一个真是图像来说,除了高斯噪声,图片的其它噪声更加复杂,并且是信号依赖的。

  • 给定一个干净图片 x,一个更加真实的噪声模型 n ( x ) ∼ N ( 0 , σ ( y ) ) n(x) \sim \mathcal{N}(0,\sigma(y)) n(x)N(0,σ(y))可以表示为: σ 2 ( x ) = x ⋅ σ s 2 + σ c 2 \sigma^2(x)=x\cdot\sigma^2_s+\sigma^2_c σ2(x)=xσs2+σc2。本文用的是异方差高斯分布,方差分成了依赖于信号的部分和平稳噪声的部分,其中 n s n_s ns是信号依赖的噪声, n c n_c nc是平稳的静态噪声分量。静态噪声分量 n c n_c nc常常建模为方差为 σ c 2 σ^2_c σc2的高斯白噪声, n s n_s ns则和图像的像素值有关,比如 x ( i ) ⋅ σ s 2 x(i)\cdot\sigma^2_s x(i)σs2
  • 除此之外,文章还进一步考虑了相机内ISP流程,其导致了下面的信号依赖和颜色通道依赖的噪声模型 y = M − 1 ( M ( f ( L + n ( x ) ) ) ) y=M^{-1}(M(f(L+n(x)))) y=M1(M(f(L+n(x))))。其中, y y y表示合成的噪声图像, f ( ⋅ ) f(\cdot) f()代表了相机响应函数(CRF),其将辐照度L转化为原始干净图像x。 M ( ⋅ ) M(\cdot) M()表示将sRGB图像转化为Bayer图像的函数, M − 1 ( ⋅ ) M^{-1}(\cdot) M1()表示去马赛克函数,原本用于去马赛克的插值方法,用它的目的是使噪声空间和颜色相关,从而增加噪声的复杂性,去马赛克函数中的线性插值运算涉及到了不同颜色通道的像素,所以合成的噪声是通道依赖的。
  • 此外,为了扩展到对压缩图片的处理,我们把 JPEG 压缩也考虑进合成图片的生成过程。 y = J P E G ( M − 1 ( M ( f ( L + n ( x ) ) ) ) ) y=JPEG(M^{-1}(M(f(L+n(x))))) y=JPEG(M1(M(f(L+n(x)))))
  • 对于噪声RAW图像,采用第一个公式合成图像。对于噪声未压缩图像,采用第二个公式来合成图像。对于噪声压缩图像,采用第三个公式来合成图像。

网络结构

论文笔记:CBDNet图像去噪网络_第2张图片
整体架构可以看到,网络由一个 全卷积网络FCN,和一个 UNet 组成。
CBDNet包含了两个子网络:噪声估计子网络和非盲去噪子网络。

  •     首先,噪声估计子网络将噪声观测图像y转换为估计的噪声水平图 σ ^ ( y ) \hat{\sigma}(y) σ^(y)。然后,非盲去噪子网络将 y y y σ ^ ( y ) \hat{\sigma}(y) σ^(y)作为输入得到最终的去噪结果 x ^ \hat{x} x^。除此之外,噪声估计子网络允许用户在估计的噪声水平图 σ ^ ( y ) \hat{\sigma}(y) σ^(y)输入到非盲去噪子网络之前对应进行调整。文章提出了一种简单的策略 ϱ ^ = σ ^ ( y ) \hat{ϱ}=\hat{\sigma}(y) ϱ^=σ^(y)
       噪声估计子网络使用五层全卷积网络,卷积核为3×3×32,并且不进行pooling和batch normalization。
  •    非盲去噪子网络使用16层的U-Net结构,且使用残差学习的方式学习残差映射 R ( y , σ ^ ( y ) ; W D ) \mathcal{R}(y,\hat{\sigma}(y);W_D) R(y,σ^(y);WD),从而得到干净的图像 x ^ = y + R ( y , σ ^ ( y ) ; W D ) \hat{x}=y+\mathcal{R}(y,\hat{\sigma}(y);W_D) x^=y+R(y,σ^(y);WD)。使用FCN进行噪声估计后,输出的noise level map与噪声图尺寸一致,加上简单起见所以没再对其进行调整,直接作为噪声估计结果送至下阶段网络。这部分网络是通过已知噪声水平的合成图像进行监督训练的。非盲去噪网络CNND采用的则是U-Net+残差的结构。
       另外,虽然在DnCNN中提到,batch normalization成功应用于高斯去噪中,但是对于真实图像的噪声去除并没有多大帮助,这可能是由于真实世界的噪声分别与高斯分布相差较大。

损失函数

在这里插入图片描述

1.非对称损失 L a s y m m \mathcal{L}_{asymm} Lasymm

   作者观察到非盲去噪方法(如BM3D、FFDNet等)对噪声估计的误差具有非对称敏感性(the asymmetric sensitivity of non-blind denoisers)。如下图所示,分别用BM3D和FFDNet使用不同的输入噪声标准差去噪(标准差依次设为5、10、15、25、35、50),其中绿色框代表输入噪声的标准差与真实噪声标准差一致。可以观察到,当输入噪声的标准差与真实噪声的标准差一致时,去噪效果最好。当输入噪声标准差低于真实值时,去噪结果包含可察觉的噪声;而当输入噪声标准差高于真实值时,去噪结果仍能保持较好的结果,虽然也平滑了部分低对比度的纹理。因此,非盲去噪方法对低估误差比较敏感,而对高估的误差比较鲁棒。正是因为这个特性,BM3D可以通过设置相对较高的输入噪声标准差得到满意的真实图像去噪效果。

   为了消除这种非对称敏感性,文章设计了非对称损失函数用于噪声估计。给定像素i的估计噪声水平 σ ^ ( y i ) \hatσ(y_i) σ^(yi)和真实值 σ ( y i ) σ(y_i) σ(yi),当 σ ^ ( y i ) < σ ( y i ) \hat{\sigma}(y_i)<\sigma(y_i) σ^(yi)<σ(yi)时,应该对其MSE引入更多的惩罚,因此非对称损失 L a s y m m \mathcal{L}_{asymm} Lasymm为: L a s y m m = ∑ i ∣ α − I ( σ ^ ( y i ) − σ ( y i ) ) < 0 ∣ ⋅ ( σ ^ ( y i ) − σ ( y i ) ) 2 \mathcal{L}_{asymm}=\sum_i{|\alpha-\mathbb{I}_{(\hat{\sigma}(y_i)-\sigma(y_i))<0}}|\cdot(\hat{\sigma}(y_i)-\sigma(y_i))^2 Lasymm=iαI(σ^(yi)σ(yi))<0(σ^(yi)σ(yi))2
当e<0时, I e = 1 \mathbb{I}_e=1 Ie=1,否则为0。通过设定0<α<0.5,我们可以对低估误差引入更多的惩罚。

2.全变分损失

   全变分模型本身是去噪任务中依靠梯度下降流对图像进行平滑的模型,motivation是要在图像内部尽可能对图像进行平滑(相邻像素的差值较小),而在图像边缘(图像轮廓)尽可能不去平滑。这里是用在噪声level map上,个人认为是通过使临近区域保持相近的噪声程度的方法来提升增强结果质量。因为反过来想,临近区域存在噪声程度突变的话,这种突兀会带来明显的视觉效果降低。所以我们引入全变分 ( T V ) (TV) (TV)正则项约束 σ ^ ( y ) \hatσ(y) σ^(y)的平滑性
L T V = ∥ ∇ h σ ^ ( y ) ∥ 2 2 + ∥ ∇ v σ ^ ( y ) ∥ 2 2 \mathcal{L}_{TV}=\|\nabla_h\hat{\sigma}(y)\|^2_2+\|\nabla_v\hat{\sigma}(y)\|^2_2 LTV=hσ^(y)22+vσ^(y)22

3.重建损失 L r e c \mathcal{L}_{rec} Lrec

   对于去噪网络输入的 x ^ \hat x x^,定义重建误差,即网络输出的去噪图像和真实无噪声图像的差距.
L r e c = ∥ x ^ − x ∥ 2 2 \mathcal{L}_{rec}=\|\hat{x}-x\|^2_2 Lrec=x^x22

训练

目前去噪数据集的建立主要分为以下三种方式:
1.从现有图像数据库获取高质量图像,然后做图像处理(如线性变化、亮度调整)并根据噪声模型添加人工合成噪声,生成噪声图像;这种方法比较简单省时,高质量图像可以直接从网上获取,但由于噪声是人工合成的,其与真实噪声图像有一定差异,使得在该数据集上训练的网络在真实噪声图像上的去噪效果受限;
2. 针对同一场景,拍摄低ISO图像作为ground truth,高ISO图像作为噪声图像,并调整曝光时间等相机参数使得两张图像亮度一致;这种方法只使用单张低ISO图像作为ground truth,难免会残留噪声,且与噪声图像可能存在亮度差异和不对齐的问题;
3. 对同一场景连续拍摄多张图像,然后做图像处理(如图像配准、异常图像剔除等),然后加权平均合成ground truth;这种方法需要拍摄大量图像,工作量比较大,且需要对图像进行严格对准,但一般得到的ground truth质量比较高;

   为了提高去噪网络的鲁棒性和泛化能力,常常需要将输入噪声图像的噪声水平估计也作为网络输入。而真实噪声图像的噪声水平估计往往存在一定误差,从这一方面考虑,合成噪声图像由于噪声模型已知,所以其噪声水平估计是准确的,有利于网络的在不同噪声水平上的泛化。CBDNet就考虑将真实噪声图像和合成噪声图像一起作为训练集,交替对网络进行训练以提升网络的性能。

   对于合成噪声图像,作为ground-truth的干净图像和噪声水平图是可用的,但噪声模型可能与真实噪声不太相符;而对于真实噪声图像,噪声是真实的,但仅仅可以获得接近无噪声的图像作为ground truth,而噪声水平图是未知的。另外,一般真实噪声图像的ground truth比较难以获取,而合成噪声图像可以比较方便的大规模合成。因此,在训练CBDNet的过程中,结合这两种类型的图像,提高网络的泛化能力。

   文章使用上述的噪声模型合成噪声图像。其使用了BSD500的400张图像,Waterloo的1600张图像,MIT-Adobe FiveK数据库的1600张图像作为训练数据。对于真实图像,使用RENOIR数据库的120张图像。

   为了提高网络的泛化能力,交替使用一批合成图像和一批真实图像进行训练。当时使用一批合成图像时,所有的损失函数都会被最小化以更新CBDNet;当使用一批真实图像时,由于真实噪声水平未知,仅仅在训练中使用 L r e c \mathcal{L}{rec} Lrec L T V \mathcal{L}_{TV} LTV



二、代码实现

本代码解析使用了 浙江大学 IDKiro 复现的代码,对他的辛苦工作表示感谢!
在原代码的基础上稍微进行了简化,因此效果应该略有不同。
原代码地址:https://github.com/IDKiro/CBDNet-pytorch

2.定义 CBDNet 网络

整体架构可以看到,网络由一个 全卷积网络FCN,和一个 UNet 组成

2.1 FCN部分

包括 5 次 conv 操作,使用 3x3 的卷积核,使用了1个像素的padding来保证尺寸一致,feature map 数量依次为:3 ==> 32 ==> 32 ==> 32 ==> 32 ==> 3

class FCN(nn.Module):
    def __init__(self):
        super(FCN, self).__init__()

        # 3 ==> 32 的输入卷积
        self.inc = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(inplace=True))
        
        # 32 ==> 32 的中间卷积
        self.conv = nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        
        # 32 ==> 3 的输出卷积 
        self.outc = nn.Sequential(
            nn.Conv2d(32, 3, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        # 第 1 次卷积
        conv1 = self.inc(x)
        # 第 2 次卷积
        conv2 = self.conv(conv1)
        # 第 3 次卷积
        conv3 = self.conv(conv2)
        # 第 4 次卷积
        conv4 = self.conv(conv3)
        # 第 5 次卷积
        conv5 = self.outc(conv4)
        return conv5

2.2UNet部分

如上面架构图所示,UNet 使用到了大量卷积,全部是 3x3 ,加了 1 个像素的padding保证尺寸,这里编写了一个 single_conv 类用于卷积操作,包括卷积 和 ReLU 函数。这个类创建需要两个参数:输入的通道数 in_ch,输出的通道数 out_ch。具体如下:

class single_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(single_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True))

    def forward(self, x):
        x = self.conv(x)
        return x

从网络架构图可以看到,网络有两次下采样,还有两次上采样。下采样使用 2x2 的均值 pooling 就可以了,如何处理上采样呢?

上采样采用的是反卷积,这里编写了一个 up 类。输入的通道数是 in_ch,输出的通道数是 in_ch//2。因为是反卷积,所以使用 nn.ConvTranspose2d 函数,卷积核大小为2。即原图像中的 1 个像素经过卷积会变成 2*2 的区域。同时,卷积的步长为 2,卷积结果紧密的拼接为一张大图。

同时,也有一些需要特殊考虑的地方,从图中可以看出,这里有一个特征融合的步骤,这时可能会产生一定的问题,那就是两个feature map 尺寸可能不一样,比如:之前尺寸是 7,下采样再上采样的话,尺寸变化为: 7 ==> 3 ==> 6 ,因为 2x2 pooling 的时候,不会考虑最边上的像素!

所以,在 forward 函数中,加入了一个 padding 操作,代码如下:

class up(nn.Module):
    def __init__(self, in_ch):
        super(up, self).__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)

    # forward 需要两个输入,x1 是需要上采样的小尺寸 feature map
    # x2 是以前的大尺寸 feature map,因为中间的 pooling 可能损失了边缘像素,
    # 所以上采样以后的 x1 可能会比 x2 尺寸小
    def forward(self, x1, x2):
        # x1 上采样
        x1 = self.up(x1)
        
        # 输入数据是四维的,第一个维度是样本数,剩下的三个维度是 CHW
        # 所以 Y 方向上的悄寸差别在 [2],  X 方向上的尺寸差别在 [3] 
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        # 给 x1 进行 padding 操作
        x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
                        diffY // 2, diffY - diffY//2))
        # 把 x2 加到反卷积后的 feature map
        x = x2 + x1
        return x

需要注意的是,输出层也写了一个类,输出部分是将 64 个 feature map,利用 1x1 的卷积变成 3 个 feature map。 教程里介绍 GoogLetNet,ResNet 的时候也有写,1x1 的卷积可以较好的起到降维作用。最后一层不使用激活函数。

class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x

下面是 UNet 部分的完整代码,和架构图示完全一致,包括:

  • input conv : 6 ==> 64 ==> 64
  • down1 : 2x2 的均值 pooling
  • conv1 : 64 ==> 128 ==> 128 ==> 128
  • down2 : 2x2 的均值 pooling
  • conv2 : 128 ==> 256 ==> 256 ==> 256 ==> 256 ==> 256 ==> 256
  • up1 : conv2 反卷积,和 conv1 的结果相加,输入256,输出128
  • conv3 : 128 ==> 128 ==> 128 ==> 128
  • up2 : conv3 反卷积,和 input conv 的结果相加,输入128,输出64
  • conv4 : 64 ==> 64 ==> 64
  • output conv: 65 ==> 3,用1x1的卷积降维,得到降噪结果

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        self.inc = nn.Sequential(
            single_conv(6, 64),
            single_conv(64, 64))

        self.down1 = nn.AvgPool2d(2)
        self.conv1 = nn.Sequential(
            single_conv(64, 128),
            single_conv(128, 128),
            single_conv(128, 128))

        self.down2 = nn.AvgPool2d(2)
        self.conv2 = nn.Sequential(
            single_conv(128, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256))

        self.up1 = up(256)
        self.conv3 = nn.Sequential(
            single_conv(128, 128),
            single_conv(128, 128),
            single_conv(128, 128))

        self.up2 = up(128)
        self.conv4 = nn.Sequential(
            single_conv(64, 64),
            single_conv(64, 64))

        self.outc = outconv(64, 3)

    def forward(self, x):
        # input conv : 6 ==> 64 ==> 64
        inx = self.inc(x)

        # 均值 pooling, 然后 conv1 : 64 ==> 128 ==> 128 ==> 128
        down1 = self.down1(inx)
        conv1 = self.conv1(down1)

        # 均值 pooling,然后 conv2 : 128 ==> 256 ==> 256 ==> 256 ==> 256 ==> 256 ==> 256
        down2 = self.down2(conv1)
        conv2 = self.conv2(down2)

        # up1 : conv2 反卷积,和 conv1 的结果相加,输入256,输出128
        up1 = self.up1(conv2, conv1)
        # conv3 : 128 ==> 128 ==> 128 ==> 128
        conv3 = self.conv3(up1)

        # up2 : conv3 反卷积,和 input conv 的结果相加,输入128,输出64
        up2 = self.up2(conv3, inx)
        # conv4 : 64 ==> 64 ==> 64
        conv4 = self.conv4(up2)

        # output conv: 65 ==> 3,用1x1的卷积降维,得到降噪结果
        out = self.outc(conv4)
        return out

2.3 CBDNet 分析

下面是 CBDNet 整个网络的代码,先将数据输入 FCN,得到估计的噪声强度: noise_level,为 3 通道。然后将 3通道的原图像,和 noise_level 拼接在一起,作为 UNet 的输入。

UNet 经过一系列操作,得到 out ,这里的 out 被认为是噪声的 residual mapping,和 输入图像加在一起,输出最终的去噪图像。

可以看出,这里也采用了一 residual learning 的思想,认为 噪声的 residual mapping 学习起来更加容易。

class CBDNet(nn.Module):
    def __init__(self):
        super(CBDNet, self).__init__()
        self.fcn = FCN()
        self.unet = UNet()
    
    def forward(self, x):
        noise_level = self.fcn(x)
        concat_img = torch.cat([x, noise_level], dim=1)
        out = self.unet(concat_img) + x
        return noise_level, out

2.4损失函数

class fixed_loss(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, out_image, gt_image, est_noise, gt_noise, if_asym):
        # 分别得到图像的高度和宽度
        h_x = est_noise.size()[2]
        w_x = est_noise.size()[3]
        # 每个样本为 CHW ,把 H 方向第一行的数据去掉,统计一下一共多少元素
        count_h = self._tensor_size(est_noise[:, :, 1:, :])
        # 每个样本为 CHW ,把 W 方向第一列的数据去掉,统计一下一共多少元素
        count_w = self._tensor_size(est_noise[:, :, : ,1:])
        # H 方向,第一行去掉得后的矩阵,减去最后一行去掉后的矩阵,即下方像素减去上方像素,平方,然后求和
        h_tv = torch.pow((est_noise[:, :, 1:, :] - est_noise[:, :, :h_x-1, :]), 2).sum()
        # W 方向,第一列去掉得后的矩阵,减去最后一列去掉后的矩阵,即右方像素减去左方像素,平方,然后求和
        w_tv = torch.pow((est_noise[:, :, :, 1:] - est_noise[:, :, :, :w_x-1]), 2).sum()
        # 求平均,得到平均每个像素上的 tvloss
        tvloss = h_tv / count_h + w_tv / count_w

        loss = torch.mean( \
                # 第三部分:重建损失
                torch.pow((out_image - gt_image), 2)) + \
                # 第一部分:对比损失
                if_asym * 0.5 * torch.mean(torch.mul(torch.abs(0.3 - F.relu(gt_noise - est_noise)), torch.pow(est_noise - gt_noise, 2))) + \
                # 第二部分:起平滑作用的 tvloss
                0.05 * tvloss
        return loss

    def _tensor_size(self,t):
        return t.size()[1]*t.size()[2]*t.size()[3]

从上面的代码中可以看到,对比损失前系数为 0.5, alpha 取值为 0.3,tvloss 系数为 0.05,和论文里的默认参数一致。

【划重点】这里需要专门指出的是:

对于 gt_noise,只有在使用合成数据进行训练时才会用到;以前的图像去噪,大多在真实图像上加一个随机Gauss噪声,得到噪声图像,这时 gt_noise 是已知的,就能够输入。

这个教程里处理的是真实图像,因此没有 gt_noise,所以在训练时,gt_noise 一直是0。原来代码里专门有一部分是人工合成噪声来训练,为方便理解代码,暂时去掉了这部分。

下面是两个程序中要用到的两个小函数:

# 这个类用于存储 loss,观察结果时使用
# 每轮训练一张图像,就计算一下 loss 的均值存储在 self.avg 里,用于输出观察变化
# 同时,把当前 loss 的值存储在 self.val 里
class AverageMeter(object):
	def __init__(self):
		self.reset()

	def reset(self):
		self.val = 0
		self.avg = 0
		self.sum = 0
		self.count = 0

	def update(self, val, n=1):
		self.val = val
		self.sum += val * n
		self.count += n
		self.avg = self.sum / self.count

# 图像矩阵由 hwc 转换为 chw ,这个就不多解释了
def hwc_to_chw(img):
    return np.transpose(img, axes=[2, 0, 1])
# 图像矩阵由 chw 转换为 hwc ,这个也不多解释
def chw_to_hwc(img):
    return np.transpose(img, axes=[1, 2, 0])

2.4 训练

# 训练的时候,输入图像尺寸都是 ps x ps 的
ps = 256

train_dir = './mini_denoise_dataset/train/'
train_fns = glob.glob(train_dir + 'Batch_*')

origin_imgs = [None] * len(train_fns)
noised_imgs = [None] * len(train_fns)

定义网络模型、优化器、损失函数

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 创建 模型 + 优化器 + 损失函数
model = CBDNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = fixed_loss()

开始训练:

cnt = 0
total_loss = AverageMeter()
# 设置为训练模式,即启用 BatchNormalization 和 Dropout
model.train()

for epoch in range(200):
    # 内存中清空图片
    for i in range(len(train_fns)):
        origin_imgs[i] = []
        noised_imgs[i] = []

    # 打乱训练图片的顺序
    for idx in np.random.permutation(len(train_fns)):    
        # 读入origin image;RGB通道反过来,然后归一化;转化为 float32类型
        origin_img = cv2.imread(glob.glob(train_fns[idx] + '/*Reference.bmp')[0])
        origin_img = origin_img[:,:,::-1] / 255.0
        origin_imgs[idx] = np.array(origin_img).astype('float32')

        # 读入noised image;因为一个文件夹里有2张噪声图,这里写了一个循环
        train_noised_list = glob.glob(train_fns[idx] + '/*Noisy.bmp')  
        for nidx in range(len(train_noised_list)):
            noised_img = cv2.imread(train_noised_list[nidx])
            noised_img = noised_img[:,:,::-1] / 255.0
            noised_img = np.array(noised_img).astype('float32')
            noised_imgs[idx].append(noised_img)

            H, W, C = origin_img.shape
            # 从图像中随机取 256x256 大小的块
            xx = np.random.randint(0, W-ps+1)
            yy = np.random.randint(0, H-ps+1)
            temp_origin_img = origin_imgs[idx][yy:yy+ps, xx:xx+ps, :]
            temp_noised_img = noised_imgs[idx][nidx][yy:yy+ps, xx:xx+ps, :]

            # 生成 0,1 随机数,随机做图像的左右、上下、通道翻转,增加训练样本的多样性
            if np.random.randint(0, 2) == 1:  # 左右翻转
                temp_origin_img = np.flip(temp_origin_img, axis=1)
                temp_noised_img = np.flip(temp_noised_img, axis=1)
            if np.random.randint(0, 2) == 1:  # 上下翻转
                temp_origin_img = np.flip(temp_origin_img, axis=0)
                temp_noised_img = np.flip(temp_noised_img, axis=0)
            if np.random.randint(0, 2) == 1:  # 通道翻转
                temp_origin_img = np.transpose(temp_origin_img, (1, 0, 2))
                temp_noised_img = np.transpose(temp_noised_img, (1, 0, 2))

            temp_noised_img_chw = hwc_to_chw(temp_noised_img)
            temp_origin_img_chw = hwc_to_chw(temp_origin_img)

            cnt += 1

            # 这里给输入数据增加一个维度,即原来是三维的,现在是四维的,方便CNN处理
            input_var  = torch.from_numpy(temp_noised_img_chw.copy()).type(torch.FloatTensor).unsqueeze(0).to(device)
            target_var = torch.from_numpy(temp_origin_img_chw.copy()).type(torch.FloatTensor).unsqueeze(0).to(device)

            # 噪声图像输入网络处理
            noise_level_est, output = model(input_var)
            # 计算损失
            loss = criterion(output, target_var, noise_level_est, 0, 0)
            total_loss.update(loss.item())
            # 常规操作: 梯度归零 + 反向传播 + 优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    print('[Epoch %d] [Img count %d] [Loss.val: %.4f] ([loss.avg: %.4f])\t' % (epoch, cnt, total_loss.val, total_loss.avg))



结果

一些基本参数:

test_dir = './mini_denoise_dataset/test/'
test_fns = glob.glob(test_dir + '*.bmp')

# 建立 result 目录,保存图片处理结果
result_dir = './result/'
if not os.path.exists( result_dir ):
    os.mkdir( result_dir )

for ind, test_img_path in enumerate(test_fns):
    model.eval()
    with torch.no_grad():
        print(test_img_path)
        # 读入图像,切换RGB通道并归一化,转化为 numpy float32格式
        noisy_img = cv2.imread(test_img_path)
        noisy_img = noisy_img[:,:,::-1] / 255.0
        noisy_img = np.array(noisy_img).astype('float32')

        # 转化为 chw 才符合 pytorch 网络的输入格式
        temp_noisy_img_chw = hwc_to_chw(noisy_img)
        # 图像放到 gpu 上
        input_var = torch.from_numpy(temp_noisy_img_chw.copy()).type(torch.FloatTensor).unsqueeze(0).to(device)
        # 输入模型得到结果
        _, output = model(input_var)

        # 输出结果转化为 numpy ,同时,把数据转到 0,1 之间(因为可能会有一些异常值)
        output_np = output.squeeze().cpu().detach().numpy()
        output_np = chw_to_hwc(np.clip(output_np, 0, 1))
        # 把噪声图像,和降噪后的图像拼接在一起,然后保存图像
        tempImg = np.concatenate((noisy_img, output_np), axis=1)*255.0
        
        Image.fromarray(np.uint8(tempImg)).save(fp=result_dir + 'test_%d.jpg'%(ind), format='JPEG')

论文笔记:CBDNet图像去噪网络_第3张图片

论文笔记:CBDNet图像去噪网络_第4张图片
论文笔记:CBDNet图像去噪网络_第5张图片
论文笔记:CBDNet图像去噪网络_第6张图片
可以看到,去噪的效果有了,但是感觉有些模糊。也许是训练数据用的不够,也许是训练的 epoch 还不够多 ~~~

你可能感兴趣的:(论文,去噪,机器学习,计算机视觉,卷积,人工智能)