代码解读——Retinex低光照图像增强(Deep Retinex Decomposition for Low-Light Enhancement)

今天带来一篇代码解读的文章,是2018年BMVC上的一篇暗光增强文章。个人觉得网络比较轻量并且能够取得还不错的效果。废话不多说,直接贴传送门:

文章地址:http://arxiv.org/abs/1808.04560

源码地址:https://github.com/weichen582/RetinexNet

文章基于Retinex理论,不懂的请戳这里:https://blog.csdn.net/lz0499/article/details/81154937

整体结构主要包括两个网络:DecomNet和RelightNet。DecomNet用于分解图片为反射分量和光照分量,RelightNet用于将光照分量修正,再与反射分量重建,得到修正后的图像。可参考下图:

代码解读——Retinex低光照图像增强(Deep Retinex Decomposition for Low-Light Enhancement)_第1张图片

其中,作者提到了在RelightNet中同时对反射分量进行去噪处理,但在代码中我没有明确看到这步操作,有知道的小伙伴可以评论区留言。

先来看DecomNet的网络构建部分。整体就是全卷积网络,具体看我代码注释。

def DecomNet(input_im, layer_num, channel=64, kernel_size=3):       #分解网络
    input_max = tf.reduce_max(input_im, axis=3, keepdims=True)
    input_im = concat([input_max, input_im])                          #选取RGB三通道中的最大值(亮度)进行堆叠,变成4通道,与最后一层卷积'recon_layer'相呼应
    with tf.variable_scope('DecomNet', reuse=tf.AUTO_REUSE):         #这里面创建的所有tensor都带有'DecomNet'
        conv = tf.layers.conv2d(input_im, channel, kernel_size * 3, padding='same', activation=None, name="shallow_feature_extraction")
        for idx in range(layer_num):                                 #进行5次带relu激活的卷积
            conv = tf.layers.conv2d(conv, channel, kernel_size, padding='same', activation=tf.nn.relu, name='activated_layer_%d' % idx)
        conv = tf.layers.conv2d(conv, 4, kernel_size, padding='same', activation=None, name='recon_layer')      # reconv到4通道以便分解

    '''将卷积结果分解成反射分量和光照分量'''
    R = tf.sigmoid(conv[:,:,:,0:3])                       #反射分量(仅由物体本身决定),反应颜色一致性,需要三通道描述
    L = tf.sigmoid(conv[:,:,:,3:4])                       #光照分量,反应光照信息,一通道即可描述(相当于亮度图)

    return R, L

然后是RelightNet的网络构建部分。

def RelightNet(input_L, input_R, channel=64, kernel_size=3):    # 恢复(调整)网络
    input_im = concat([input_R, input_L])
    with tf.variable_scope('RelightNet'):
        '''3次下采样'''
        conv0 = tf.layers.conv2d(input_im, channel, kernel_size, padding='same', activation=None)
        conv1 = tf.layers.conv2d(conv0, channel, kernel_size, strides=2, padding='same', activation=tf.nn.relu)
        conv2 = tf.layers.conv2d(conv1, channel, kernel_size, strides=2, padding='same', activation=tf.nn.relu)
        conv3 = tf.layers.conv2d(conv2, channel, kernel_size, strides=2, padding='same', activation=tf.nn.relu)

        '''3次上采样   最近邻插值'''
        up1 = tf.image.resize_nearest_neighbor(conv3, (tf.shape(conv2)[1], tf.shape(conv2)[2]))
        deconv1 = tf.layers.conv2d(up1, channel, kernel_size, padding='same', activation=tf.nn.relu) + conv2
        up2 = tf.image.resize_nearest_neighbor(deconv1, (tf.shape(conv1)[1], tf.shape(conv1)[2]))
        deconv2= tf.layers.conv2d(up2, channel, kernel_size, padding='same', activation=tf.nn.relu) + conv1
        up3 = tf.image.resize_nearest_neighbor(deconv2, (tf.shape(conv0)[1], tf.shape(conv0)[2]))
        deconv3 = tf.layers.conv2d(up3, channel, kernel_size, padding='same', activation=tf.nn.relu) + conv0

        '''多尺度特征融合,在不同尺度上对光照分量进行恢复'''
        deconv1_resize = tf.image.resize_nearest_neighbor(deconv1, (tf.shape(deconv3)[1], tf.shape(deconv3)[2]))
        deconv2_resize = tf.image.resize_nearest_neighbor(deconv2, (tf.shape(deconv3)[1], tf.shape(deconv3)[2]))
        feature_gather = concat([deconv1_resize, deconv2_resize, deconv3])
        feature_fusion = tf.layers.conv2d(feature_gather, channel, 1, padding='same', activation=None)
        output = tf.layers.conv2d(feature_fusion, filters=1, kernel_size=3, padding='same', activation=None)
    return output                         #返回单通道图像,即修正后的光照分量

接下来重点看损失函数的构建部分。首先看DecomNet的损失部分:

self.input_low = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low')      #占位符定义,喂数据
self.input_high = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high')

[R_low, I_low] = DecomNet(self.input_low, layer_num=self.DecomNet_layer_num)
[R_high, I_high] = DecomNet(self.input_high, layer_num=self.DecomNet_layer_num)

I_low_3 = concat([I_low, I_low, I_low])
I_high_3 = concat([I_high, I_high, I_high])
I_delta_3 = concat([I_delta, I_delta, I_delta])

self.recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 -  self.input_low))      #分解出的两个图相乘 与 原来的图应该一样(Retinex理论)保证Decom的正确性
self.recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - self.input_high))   #同上
self.recon_loss_mutal_low = tf.reduce_mean(tf.abs(R_high * I_low_3 - self.input_low)) #保证R_high==R_low(反射分量一致性)
self.recon_loss_mutal_high = tf.reduce_mean(tf.abs(R_low * I_high_3 - self.input_high)) #同上
self.equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high))                              #同上

self.Ismooth_loss_low = self.smooth(I_low, R_low)             # 平滑光照图的同时保留边缘信息,适应原图特征(用梯度实现)
self.Ismooth_loss_high = self.smooth(I_high, R_high)

self.loss_Decom = self.recon_loss_low + self.recon_loss_high + 0.001 * self.recon_loss_mutal_low + 0.001 * self.recon_loss_mutal_high + 0.1 * self.Ismooth_loss_low + 0.1 * self.Ismooth_loss_high + 0.01 * self.equal_R_loss

可以看到,作者设计了挺多的函数,如同我注释中所说的。其中,平滑约束是我觉得最巧妙的地方。因为光照分量一般来说是低频部分,相对而言应该是平滑的,而反射分量反应物体特征,应该是细节丰富的。而且在图像边缘区,即便是光照分量也不能太过平滑,否则就类似高斯平滑失去特征了。有了平滑损失就可以自适应调节了。

其中,平滑函数,梯度函数定义如下:

def gradient(self, input_tensor, direction):
    self.smooth_kernel_x = tf.reshape(tf.constant([[0, 0], [-1, 1]], tf.float32), [2, 2, 1, 1])
    self.smooth_kernel_y = tf.transpose(self.smooth_kernel_x, [1, 0, 2, 3])

    if direction == "x":
       kernel = self.smooth_kernel_x
    elif direction == "y":
        kernel = self.smooth_kernel_y
    return tf.abs(tf.nn.conv2d(input_tensor, kernel, strides=[1, 1, 1, 1], padding='SAME'))

def ave_gradient(self, input_tensor, direction):
    return tf.layers.average_pooling2d(self.gradient(input_tensor, direction), pool_size=3, strides=1, padding='SAME')

def smooth(self, input_I, input_R):     #   给光照图I的梯度赋权值,实现自适应调节作用。反射图R梯度越小,赋予的权值越大,使光照图梯度减小,变得平滑
    input_R = tf.image.rgb_to_grayscale(input_R)
    return tf.reduce_mean(self.gradient(input_I, "x") * tf.exp(-10 * self.ave_gradient(input_R, "x")) + self.gradient(input_I, "y") * tf.exp(-10 * self.ave_gradient(input_R, "y")))

然后就是RelightNet的损失部分了:

self.relight_loss = tf.reduce_mean(tf.abs(R_low * I_delta_3 - self.input_high))       #保证恢复图像的正确(与低质量图反射分量重建后 接近 高质量图)
self.Ismooth_loss_delta = self.smooth(I_delta, R_low)

self.loss_Relight = self.relight_loss + 3 * self.Ismooth_loss_delta

这部分比较简单,就一个重建损失加上平滑损失。

下面是数据增强部分,主要就是旋转操作加翻转操作。

def data_augmentation(image, mode):
    if mode == 0:
        # original
        return image
    elif mode == 1:
        # flip up and down
        return np.flipud(image)
    elif mode == 2:
        # rotate counterwise 90 degree
        return np.rot90(image)
    elif mode == 3:
        # rotate 90 degree and flip up and down
        image = np.rot90(image)
        return np.flipud(image)
    elif mode == 4:
        # rotate 180 degree
        return np.rot90(image, k=2)
    elif mode == 5:
        # rotate 180 degree and flip
        image = np.rot90(image, k=2)
        return np.flipud(image)
    elif mode == 6:
        # rotate 270 degree
        return np.rot90(image, k=3)
    elif mode == 7:
        # rotate 270 degree and flip
        image = np.rot90(image, k=3)
        return np.flipud(image)

还有一点,代码是将整个数据集里的图片全部读入内存再处理的,这对于个人电脑来说有点不现实,所以最好写个生成器。这里我就不贴代码了,有需要的可以私聊我。

最后,谈谈缺点吧,就是处理后的图片色彩有点失真。主要还是因为Decom-Net对低光照/正常光照图像分解出来的反射分量无法做到完全一致吧。

但总体来说,还是很棒的了。

你可能感兴趣的:(代码解读——Retinex低光照图像增强(Deep Retinex Decomposition for Low-Light Enhancement))