今天带来一篇代码解读的文章,是2018年BMVC上的一篇暗光增强文章。个人觉得网络比较轻量并且能够取得还不错的效果。废话不多说,直接贴传送门:
文章地址:http://arxiv.org/abs/1808.04560
源码地址:https://github.com/weichen582/RetinexNet
文章基于Retinex理论,不懂的请戳这里:https://blog.csdn.net/lz0499/article/details/81154937
整体结构主要包括两个网络:DecomNet和RelightNet。DecomNet用于分解图片为反射分量和光照分量,RelightNet用于将光照分量修正,再与反射分量重建,得到修正后的图像。可参考下图:
其中,作者提到了在RelightNet中同时对反射分量进行去噪处理,但在代码中我没有明确看到这步操作,有知道的小伙伴可以评论区留言。
先来看DecomNet的网络构建部分。整体就是全卷积网络,具体看我代码注释。
def DecomNet(input_im, layer_num, channel=64, kernel_size=3): #分解网络
input_max = tf.reduce_max(input_im, axis=3, keepdims=True)
input_im = concat([input_max, input_im]) #选取RGB三通道中的最大值(亮度)进行堆叠,变成4通道,与最后一层卷积'recon_layer'相呼应
with tf.variable_scope('DecomNet', reuse=tf.AUTO_REUSE): #这里面创建的所有tensor都带有'DecomNet'
conv = tf.layers.conv2d(input_im, channel, kernel_size * 3, padding='same', activation=None, name="shallow_feature_extraction")
for idx in range(layer_num): #进行5次带relu激活的卷积
conv = tf.layers.conv2d(conv, channel, kernel_size, padding='same', activation=tf.nn.relu, name='activated_layer_%d' % idx)
conv = tf.layers.conv2d(conv, 4, kernel_size, padding='same', activation=None, name='recon_layer') # reconv到4通道以便分解
'''将卷积结果分解成反射分量和光照分量'''
R = tf.sigmoid(conv[:,:,:,0:3]) #反射分量(仅由物体本身决定),反应颜色一致性,需要三通道描述
L = tf.sigmoid(conv[:,:,:,3:4]) #光照分量,反应光照信息,一通道即可描述(相当于亮度图)
return R, L
然后是RelightNet的网络构建部分。
def RelightNet(input_L, input_R, channel=64, kernel_size=3): # 恢复(调整)网络
input_im = concat([input_R, input_L])
with tf.variable_scope('RelightNet'):
'''3次下采样'''
conv0 = tf.layers.conv2d(input_im, channel, kernel_size, padding='same', activation=None)
conv1 = tf.layers.conv2d(conv0, channel, kernel_size, strides=2, padding='same', activation=tf.nn.relu)
conv2 = tf.layers.conv2d(conv1, channel, kernel_size, strides=2, padding='same', activation=tf.nn.relu)
conv3 = tf.layers.conv2d(conv2, channel, kernel_size, strides=2, padding='same', activation=tf.nn.relu)
'''3次上采样 最近邻插值'''
up1 = tf.image.resize_nearest_neighbor(conv3, (tf.shape(conv2)[1], tf.shape(conv2)[2]))
deconv1 = tf.layers.conv2d(up1, channel, kernel_size, padding='same', activation=tf.nn.relu) + conv2
up2 = tf.image.resize_nearest_neighbor(deconv1, (tf.shape(conv1)[1], tf.shape(conv1)[2]))
deconv2= tf.layers.conv2d(up2, channel, kernel_size, padding='same', activation=tf.nn.relu) + conv1
up3 = tf.image.resize_nearest_neighbor(deconv2, (tf.shape(conv0)[1], tf.shape(conv0)[2]))
deconv3 = tf.layers.conv2d(up3, channel, kernel_size, padding='same', activation=tf.nn.relu) + conv0
'''多尺度特征融合,在不同尺度上对光照分量进行恢复'''
deconv1_resize = tf.image.resize_nearest_neighbor(deconv1, (tf.shape(deconv3)[1], tf.shape(deconv3)[2]))
deconv2_resize = tf.image.resize_nearest_neighbor(deconv2, (tf.shape(deconv3)[1], tf.shape(deconv3)[2]))
feature_gather = concat([deconv1_resize, deconv2_resize, deconv3])
feature_fusion = tf.layers.conv2d(feature_gather, channel, 1, padding='same', activation=None)
output = tf.layers.conv2d(feature_fusion, filters=1, kernel_size=3, padding='same', activation=None)
return output #返回单通道图像,即修正后的光照分量
接下来重点看损失函数的构建部分。首先看DecomNet的损失部分:
self.input_low = tf.placeholder(tf.float32, [None, None, None, 3], name='input_low') #占位符定义,喂数据
self.input_high = tf.placeholder(tf.float32, [None, None, None, 3], name='input_high')
[R_low, I_low] = DecomNet(self.input_low, layer_num=self.DecomNet_layer_num)
[R_high, I_high] = DecomNet(self.input_high, layer_num=self.DecomNet_layer_num)
I_low_3 = concat([I_low, I_low, I_low])
I_high_3 = concat([I_high, I_high, I_high])
I_delta_3 = concat([I_delta, I_delta, I_delta])
self.recon_loss_low = tf.reduce_mean(tf.abs(R_low * I_low_3 - self.input_low)) #分解出的两个图相乘 与 原来的图应该一样(Retinex理论)保证Decom的正确性
self.recon_loss_high = tf.reduce_mean(tf.abs(R_high * I_high_3 - self.input_high)) #同上
self.recon_loss_mutal_low = tf.reduce_mean(tf.abs(R_high * I_low_3 - self.input_low)) #保证R_high==R_low(反射分量一致性)
self.recon_loss_mutal_high = tf.reduce_mean(tf.abs(R_low * I_high_3 - self.input_high)) #同上
self.equal_R_loss = tf.reduce_mean(tf.abs(R_low - R_high)) #同上
self.Ismooth_loss_low = self.smooth(I_low, R_low) # 平滑光照图的同时保留边缘信息,适应原图特征(用梯度实现)
self.Ismooth_loss_high = self.smooth(I_high, R_high)
self.loss_Decom = self.recon_loss_low + self.recon_loss_high + 0.001 * self.recon_loss_mutal_low + 0.001 * self.recon_loss_mutal_high + 0.1 * self.Ismooth_loss_low + 0.1 * self.Ismooth_loss_high + 0.01 * self.equal_R_loss
可以看到,作者设计了挺多的函数,如同我注释中所说的。其中,平滑约束是我觉得最巧妙的地方。因为光照分量一般来说是低频部分,相对而言应该是平滑的,而反射分量反应物体特征,应该是细节丰富的。而且在图像边缘区,即便是光照分量也不能太过平滑,否则就类似高斯平滑失去特征了。有了平滑损失就可以自适应调节了。
其中,平滑函数,梯度函数定义如下:
def gradient(self, input_tensor, direction):
self.smooth_kernel_x = tf.reshape(tf.constant([[0, 0], [-1, 1]], tf.float32), [2, 2, 1, 1])
self.smooth_kernel_y = tf.transpose(self.smooth_kernel_x, [1, 0, 2, 3])
if direction == "x":
kernel = self.smooth_kernel_x
elif direction == "y":
kernel = self.smooth_kernel_y
return tf.abs(tf.nn.conv2d(input_tensor, kernel, strides=[1, 1, 1, 1], padding='SAME'))
def ave_gradient(self, input_tensor, direction):
return tf.layers.average_pooling2d(self.gradient(input_tensor, direction), pool_size=3, strides=1, padding='SAME')
def smooth(self, input_I, input_R): # 给光照图I的梯度赋权值,实现自适应调节作用。反射图R梯度越小,赋予的权值越大,使光照图梯度减小,变得平滑
input_R = tf.image.rgb_to_grayscale(input_R)
return tf.reduce_mean(self.gradient(input_I, "x") * tf.exp(-10 * self.ave_gradient(input_R, "x")) + self.gradient(input_I, "y") * tf.exp(-10 * self.ave_gradient(input_R, "y")))
然后就是RelightNet的损失部分了:
self.relight_loss = tf.reduce_mean(tf.abs(R_low * I_delta_3 - self.input_high)) #保证恢复图像的正确(与低质量图反射分量重建后 接近 高质量图)
self.Ismooth_loss_delta = self.smooth(I_delta, R_low)
self.loss_Relight = self.relight_loss + 3 * self.Ismooth_loss_delta
这部分比较简单,就一个重建损失加上平滑损失。
下面是数据增强部分,主要就是旋转操作加翻转操作。
def data_augmentation(image, mode):
if mode == 0:
# original
return image
elif mode == 1:
# flip up and down
return np.flipud(image)
elif mode == 2:
# rotate counterwise 90 degree
return np.rot90(image)
elif mode == 3:
# rotate 90 degree and flip up and down
image = np.rot90(image)
return np.flipud(image)
elif mode == 4:
# rotate 180 degree
return np.rot90(image, k=2)
elif mode == 5:
# rotate 180 degree and flip
image = np.rot90(image, k=2)
return np.flipud(image)
elif mode == 6:
# rotate 270 degree
return np.rot90(image, k=3)
elif mode == 7:
# rotate 270 degree and flip
image = np.rot90(image, k=3)
return np.flipud(image)
还有一点,代码是将整个数据集里的图片全部读入内存再处理的,这对于个人电脑来说有点不现实,所以最好写个生成器。这里我就不贴代码了,有需要的可以私聊我。
最后,谈谈缺点吧,就是处理后的图片色彩有点失真。主要还是因为Decom-Net对低光照/正常光照图像分解出来的反射分量无法做到完全一致吧。
但总体来说,还是很棒的了。