微光图像增强的零参考深度曲线估计

论文地址:https://arxiv.org/abs/2001.06826

文章目录

      • 思路总结
      • 损失函数代码实现(Myloss):
      • 模型代码实现(model)
      • dataloader
      • 程序训练(lowlight_train)

思路总结

亮点:不需要成对训练数据;仅训练每个像素的增强高阶方程参数;网络损失函数考虑了图像空间一致性,曝光,颜色一致性和亮度

在这篇论文中,我们呈现了一种新的深度学习方法,零参考深度曲线估计,来进行微光图像增强。它可以在各种各样的灯光条件包括不均匀和弱光情况进行处理。不同于执行图像到图像的映射,我们把任务重新设定为一个特定图像曲线估计问题。特别地,提出的这种方法把一个人微光图像作为输入,并把产生的高阶曲线作为它的输出。然后这些曲线被用作对输入的变化范围的像素级调整,从而获得一个增强的图像。曲线估计是精心制定的以便于它保持图像增强的范围并保留相邻像素的对比度。重要的是,它是可微的,因此,我们可以通过一个深度卷积神经网络来了解曲线的可调参数。所提出的网络是轻量级的,它可以迭代地应用于近似高阶曲线,以获得更稳健和更精确的动态范围调整。

一个特别的优势是我们的深度学习给予的方法是零参考。与现有的基于CNN和GAN的方法一样,它在训练过程中不需要任何成对的或者甚至不成对的数据。这是通过一组特别设计的非参考损失函数实现的,这些函数包括空间一致性损失、曝光控制损失、颜色恒定性损失和光照平滑度损失,所有这些都考虑了光增强的多因素。我们发现,即使使用零参考训练,zero-DCE仍然可以与其他需要成对或不成对数据进行训练的方法相比具有竞争力。图1示出了增强包含非均匀照明的微光图像的示例。与最新的方法相比,零DCE在保持固有颜色和细节的同时使图像变亮。相比之下,基于CNN的方法[28]和基于GAN的EnightGan都产生了低于(面部)和过度(内阁)的增强。
微光图像增强的零参考深度曲线估计_第1张图片
典型微光图像的视觉比较。所提出的零DCE在亮度、颜色、对比度和自然度方面都达到了令人满意的效果,而现有的方法要么无法处理极端的背光,要么产生颜色伪影。与其他基于深度学习的方法相比,我们的方法在没有任何参考图像的情况下进行训练。

我们的贡献总结如下。
1) 我们提出了第一个弱光增强网络,它独立于成对和不成对的训练数据,从而避免了过度拟合的风险。结果表明,该方法能很好地推广到各种照明条件下。

2) 我们设计了一个特定于图像的曲线,它能够通过迭代应用自身来近似像素级和高阶曲线。这种特定于图像的曲线可以在较宽的动态范围内有效地进行映射。

3) 我们展示了在没有参考图像的情况下,通过任务特定的非参考损失函数(间接评估增强质量)训练深度图像增强模型的潜力。

我们的零DCE方法在定性和定量度量方面都取代了最先进的性能。更重要的是,它能够改进高级视觉任务,例如人脸检测,而不会造成很高的计算负担。它能够实时处理图像(在GPU上,640×480×3大小的图像约为500fps),训练只需30分钟。

难点:
1.确定怎样的方程式子?
增强映射方程式如下:
在这里插入图片描述
该方程式需要满足三个条件:输出值【0,1】,避免计算过程中值溢出;简单且可微;保证图像相邻相邻像素的差异,其中阿尔法就是需要学习的参数,不但可以增强图像,也可以控制曝光
2. 如何迭代?
迭代方式如下:
在这里插入图片描述
原理类似于递归调用,上一次计算的输出作为本次的输入,不断迭代,可以理解为对输入图像不断进行这样的迭代,不断尝试,直到找到最优图像,也即loss最小。不过迭代多少次,本文从实验获取为8次。
3. 怎样与CNN结合?
网络架构如下:
微光图像增强的零参考深度曲线估计_第2张图片

对输入图像的每个信道分别做迭代操作,每次迭代操作的输出和输入图像map层再次结合作为下层输入。loss的计算实际就是计算最后增强的图像和原始输入图像之间的gap最大。

  • 损失函数如何设计?
    本文loss包含四种,分别为:
    空间一致性loss:
    在这里插入图片描述
    曝光控制loss:
    在这里插入图片描述
    色彩恒定性loss:
    在这里插入图片描述
    光照平滑度loss:
    微光图像增强的零参考深度曲线估计_第3张图片
    总loss:
    在这里插入图片描述
    我们在图4中给出了由各种损失组合训练的零度误差的结果。
    -没有空间一致性损失的结果与完整结果相比,对比度相对较低(例如,云区域)。这说明了Lspa在保持输入图像和增强图像之间相邻区域的差异方面的重要性。
    移除曝光控制损失Lexp无法恢复低光区域。
    颜色恒定性损失Lcol被丢弃时,会出现严重的颜色投射。应用曲线映射时,此变量忽略三个通道之间的关系。
    去除光照平滑度损失Ltva会阻碍相邻区域之间的相关性,从而导致明显的伪影。
    微光图像增强的零参考深度曲线估计_第4张图片

损失函数代码实现(Myloss):

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision.models.vgg import vgg16
import pytorch_colors as colors
import numpy as np


class L_color(nn.Module):

    def __init__(self):
        super(L_color, self).__init__()

    def forward(self, x ):

        b,c,h,w = x.shape

        mean_rgb = torch.mean(x,[2,3],keepdim=True)
        mr,mg, mb = torch.split(mean_rgb, 1, dim=1)
        Drg = torch.pow(mr-mg,2)
        Drb = torch.pow(mr-mb,2)
        Dgb = torch.pow(mb-mg,2)
        k = torch.pow(torch.pow(Drg,2) + torch.pow(Drb,2) + torch.pow(Dgb,2),0.5)


        return k

			
class L_spa(nn.Module):

    def __init__(self):
        super(L_spa, self).__init__()
        # print(1)kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
        kernel_left = torch.FloatTensor( [[0,0,0],[-1,1,0],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
        kernel_right = torch.FloatTensor( [[0,0,0],[0,1,-1],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
        kernel_up = torch.FloatTensor( [[0,-1,0],[0,1, 0 ],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
        kernel_down = torch.FloatTensor( [[0,0,0],[0,1, 0],[0,-1,0]]).cuda().unsqueeze(0).unsqueeze(0)
        self.weight_left = nn.Parameter(data=kernel_left, requires_grad=False)
        self.weight_right = nn.Parameter(data=kernel_right, requires_grad=False)
        self.weight_up = nn.Parameter(data=kernel_up, requires_grad=False)
        self.weight_down = nn.Parameter(data=kernel_down, requires_grad=False)
        self.pool = nn.AvgPool2d(4)

    def forward(self, org , enhance ):
        b,c,h,w = org.shape

        org_mean = torch.mean(org,1,keepdim=True)
        enhance_mean = torch.mean(enhance,1,keepdim=True)

        org_pool =  self.pool(org_mean)			
        enhance_pool = self.pool(enhance_mean)	

        weight_diff =torch.max(torch.FloatTensor([1]).cuda() + 10000*torch.min(org_pool - torch.FloatTensor([0.3]).cuda(),torch.FloatTensor([0]).cuda()),torch.FloatTensor([0.5]).cuda())
        E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5]).cuda()) ,enhance_pool-org_pool)


        D_org_letf = F.conv2d(org_pool , self.weight_left, padding=1)
        D_org_right = F.conv2d(org_pool , self.weight_right, padding=1)
        D_org_up = F.conv2d(org_pool , self.weight_up, padding=1)
        D_org_down = F.conv2d(org_pool , self.weight_down, padding=1)

        D_enhance_letf = F.conv2d(enhance_pool , self.weight_left, padding=1)
        D_enhance_right = F.conv2d(enhance_pool , self.weight_right, padding=1)
        D_enhance_up = F.conv2d(enhance_pool , self.weight_up, padding=1)
        D_enhance_down = F.conv2d(enhance_pool , self.weight_down, padding=1)

        D_left = torch.pow(D_org_letf - D_enhance_letf,2)
        D_right = torch.pow(D_org_right - D_enhance_right,2)
        D_up = torch.pow(D_org_up - D_enhance_up,2)
        D_down = torch.pow(D_org_down - D_enhance_down,2)
        E = (D_left + D_right + D_up +D_down)
        # E = 25*(D_left + D_right + D_up +D_down)

        return E
class L_exp(nn.Module):

    def __init__(self,patch_size,mean_val):
        super(L_exp, self).__init__()
        # print(1)
        self.pool = nn.AvgPool2d(patch_size)
        self.mean_val = mean_val
    def forward(self, x ):

        b,c,h,w = x.shape
        x = torch.mean(x,1,keepdim=True)
        mean = self.pool(x)

        d = torch.mean(torch.pow(mean- torch.FloatTensor([self.mean_val] ).cuda(),2))
        return d
        
class L_TV(nn.Module):
    def __init__(self,TVLoss_weight=1):
        super(L_TV,self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self,x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h =  (x.size()[2]-1) * x.size()[3]
        count_w = x.size()[2] * (x.size()[3] - 1)
        h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
        w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
        return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size

class Sa_Loss(nn.Module):
    def __init__(self):
        super(Sa_Loss, self).__init__()
        # print(1)
    def forward(self, x ):
        # self.grad = np.ones(x.shape,dtype=np.float32)
        b,c,h,w = x.shape
        # x_de = x.cpu().detach().numpy()
        r,g,b = torch.split(x , 1, dim=1)
        mean_rgb = torch.mean(x,[2,3],keepdim=True)
        mr,mg, mb = torch.split(mean_rgb, 1, dim=1)
        Dr = r-mr
        Dg = g-mg
        Db = b-mb
        k =torch.pow( torch.pow(Dr,2) + torch.pow(Db,2) + torch.pow(Dg,2),0.5)
        # print(k)
        
        k = torch.mean(k)
        return k

class perception_loss(nn.Module):
    def __init__(self):
        super(perception_loss, self).__init__()
        features = vgg16(pretrained=True).features
        self.to_relu_1_2 = nn.Sequential() 
        self.to_relu_2_2 = nn.Sequential() 
        self.to_relu_3_3 = nn.Sequential()
        self.to_relu_4_3 = nn.Sequential()

        for x in range(4):
            self.to_relu_1_2.add_module(str(x), features[x])
        for x in range(4, 9):
            self.to_relu_2_2.add_module(str(x), features[x])
        for x in range(9, 16):
            self.to_relu_3_3.add_module(str(x), features[x])
        for x in range(16, 23):
            self.to_relu_4_3.add_module(str(x), features[x])
        
        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        h = self.to_relu_1_2(x)
        h_relu_1_2 = h
        h = self.to_relu_2_2(h)
        h_relu_2_2 = h
        h = self.to_relu_3_3(h)
        h_relu_3_3 = h
        h = self.to_relu_4_3(h)
        h_relu_4_3 = h
        # out = (h_relu_1_2, h_relu_2_2, h_relu_3_3, h_relu_4_3)
        return h_relu_4_3

模型代码实现(model)

我们评估了零DCE中参数的影响,包括DCE网络的深度和宽度以及迭代次数。在图5(b)中,仅使用三个卷积层,Zero DCE 3-32-8已经可以产生令人满意的结果,表明零参考学习的有效性。Zero DCE 7-32-8和Zero DCE 7-32-16通过自然曝光和适当的对比度产生最令人满意的视觉效果。通过将迭代次数减少到1,在零DCE 7-32-1上观察到性能的明显下降,如图5(d)所示。这是因为只有一次迭代的曲线调整能力有限。这表明我们的方法需要高阶曲线。基于Zero-dce7-32-8在效率和恢复性能之间的良好折衷,我们选择了Zero-dce7-32-8作为最终模型。
微光图像增强的零参考深度曲线估计_第5张图片

import torch
import torch.nn as nn
import math
import pytorch_colors as colors
import numpy as np

class enhance_net_nopool(nn.Module):  #创建增强网络结构

	def __init__(self):
		super(enhance_net_nopool, self).__init__()

		self.relu = nn.ReLU(inplace=True) 

		number_f = 32  #通道为32个
		self.e_conv1 = nn.Conv2d(3,number_f,3,1,1,bias=True) 
		self.e_conv2 = nn.Conv2d(number_f,number_f,3,1,1,bias=True) 
		self.e_conv3 = nn.Conv2d(number_f,number_f,3,1,1,bias=True) 
		self.e_conv4 = nn.Conv2d(number_f,number_f,3,1,1,bias=True) 
		self.e_conv5 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True) 
		self.e_conv6 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True) 
		self.e_conv7 = nn.Conv2d(number_f*2,24,3,1,1,bias=True)       #创建7个卷积层

		self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)   #最大池化层
		self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)   #用于2D数据的线性插值算法  
		
	def forward(self, x):   

		x1 = self.relu(self.e_conv1(x))
		# p1 = self.maxpool(x1)
		x2 = self.relu(self.e_conv2(x1))
		# p2 = self.maxpool(x2)
		x3 = self.relu(self.e_conv3(x2))
		# p3 = self.maxpool(x3)
		x4 = self.relu(self.e_conv4(x3))

		x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))  #在维度1上拼接x3,x4
		# x5 = self.upsample(x5)
		x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))

		x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1)))
		r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 3, dim=1)  ##在维度1上进行划分,每大块包含3个小块


		x = x + r1*(torch.pow(x,2)-x)
		x = x + r2*(torch.pow(x,2)-x)
		x = x + r3*(torch.pow(x,2)-x)
		enhance_image_1 = x + r4*(torch.pow(x,2)-x)		
		x = enhance_image_1 + r5*(torch.pow(enhance_image_1,2)-enhance_image_1)		
		x = x + r6*(torch.pow(x,2)-x)	
		x = x + r7*(torch.pow(x,2)-x)
		enhance_image = x + r8*(torch.pow(x,2)-x)
		r = torch.cat([r1,r2,r3,r4,r5,r6,r7,r8],1)
		return enhance_image_1,enhance_image,r

dataloader

import os
import sys

import torch
import torch.utils.data as data

import numpy as np
from PIL import Image
import glob
import random
import cv2

random.seed(1143)


def populate_train_list(lowlight_images_path): #获取训练列表(微光图像路径)

	image_list_lowlight = glob.glob(lowlight_images_path + "*.jpg")#提取微光图像
	train_list = image_list_lowlight #将其作为训练列表
	random.shuffle(train_list) #将列表的元素顺序打乱
	return train_list  #返回打乱顺序后的列表

class lowlight_loader(data.Dataset): #创建微光图像类(数据集)

	def __init__(self, lowlight_images_path): #初始化

		self.train_list = populate_train_list(lowlight_images_path) #将获取的训练列表存入self训练列表
		self.size = 256  
		
		self.data_list = self.train_list  #数据列表
		print("Total training examples:", len(self.train_list))  #打印训练列表中元素的个数

	def __getitem__(self, index):  #实现对象迭代

		data_lowlight_path = self.data_list[index] #索引数据列表
		
		data_lowlight = Image.open(data_lowlight_path) #读取数据列表文件的位置
		
		data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)
		#调整图像大小为256*256,第二个参数为Image.ANTIALIAS:高质量

		data_lowlight = (np.asarray(data_lowlight)/255.0) #把数据列表装换为数组

		data_lowlight = torch.from_numpy(data_lowlight).float()#把数组转化内浮点型张量

		return data_lowlight.permute(2,0,1) 

	def __len__(self):

		return len(self.data_list)  #返回数据列表长度

if __name__ == "__main__":

	train_list = populate_train_list('Zero-DCE-master/Zero-DCE_code/data/train_data/')
	print(train_list)

程序训练(lowlight_train)

import torch
import torch.nn as nn
import torchvision
import torch.backends.cudnn as cudnn
import torch.optim
import os
import sys
import argparse
import time
import dataloader
import model
import Myloss
import numpy as np
from torchvision import transforms


def weights_init(m):   #初始化权重 
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


def train(config):  #训练()

	os.environ['CUDA_VISIBLE_DEVICES']='0'

	DCE_net = model.enhance_net_nopool().cuda()  #深度曲线估计网络

	DCE_net.apply(weights_init)
	if config.load_pretrain == True:
	    DCE_net.load_state_dict(torch.load(config.pretrain_dir))   #load_pretrain == False

	train_dataset = dataloader.lowlight_loader(config.lowlight_images_path)	 #256个数据集	
	
	train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True)


	L_color = Myloss.L_color()
	L_spa = Myloss.L_spa()

	L_exp = Myloss.L_exp(16,0.6)
	L_TV = Myloss.L_TV()
	

	optimizer = torch.optim.Adam(DCE_net.parameters(), lr=config.lr, weight_decay=config.weight_decay)
	
	DCE_net.train()

	for epoch in range(config.num_epochs):
		for iteration, img_lowlight in enumerate(train_loader):

			img_lowlight = img_lowlight.cuda()

			enhanced_image_1,enhanced_image,A  = DCE_net(img_lowlight)

			Loss_TV = 200*L_TV(A)
			
			loss_spa = torch.mean(L_spa(enhanced_image, img_lowlight))

			loss_col = 5*torch.mean(L_color(enhanced_image))

			loss_exp = 10*torch.mean(L_exp(enhanced_image))
			
			# best_loss
			loss =  Loss_TV + loss_spa + loss_col + loss_exp
			#
			optimizer.zero_grad()
			loss.backward()
			torch.nn.utils.clip_grad_norm_(DCE_net.parameters(),config.grad_clip_norm) #梯度裁剪
			optimizer.step()

			if ((iteration+1) % config.display_iter) == 0:
				print("Loss at iteration", iteration+1, ":", loss.item())
			if ((iteration+1) % config.snapshot_iter) == 0:
				torch.save(DCE_net.state_dict(), config.snapshots_folder + "Epoch" + str(epoch) + '.pth') 	#保存模型参数	

if __name__ == "__main__":

	parser = argparse.ArgumentParser()

	# Input Parameters
	parser.add_argument('--lowlight_images_path', type=str, default="Zero-DCE-master/Zero-DCE_code/data/train_data/")
	parser.add_argument('--lr', type=float, default=0.0001)
	parser.add_argument('--weight_decay', type=float, default=0.0001)
	parser.add_argument('--grad_clip_norm', type=float, default=0.1)
	parser.add_argument('--num_epochs', type=int, default=200)
	parser.add_argument('--train_batch_size', type=int, default=1)
	parser.add_argument('--val_batch_size', type=int, default=4)
	parser.add_argument('--num_workers', type=int, default=4)
	parser.add_argument('--display_iter', type=int, default=10)
	parser.add_argument('--snapshot_iter', type=int, default=10)
	parser.add_argument('--snapshots_folder', type=str, default="snapshots/")
	parser.add_argument('--load_pretrain', type=bool, default= False)
	parser.add_argument('--pretrain_dir', type=str, default= "snapshots/Epoch99.pth")

	config = parser.parse_args()
	#
	# print(config.lowlight_images_path,config.lr)

	if not os.path.exists(config.snapshots_folder):
		os.mkdir(config.snapshots_folder)
		
	train(config)

训练结果:

Loss at iteration 10 : 1.191946029663086
Loss at iteration 20 : 0.8900860548019409
Loss at iteration 30 : 1.0086252689361572
Loss at iteration 40 : 0.6189944744110107
Loss at iteration 50 : 0.7224398255348206
Loss at iteration 60 : 0.965330958366394
Loss at iteration 70 : 0.44430315494537354
Loss at iteration 80 : 0.9305168986320496
Loss at iteration 90 : 0.4105320870876312
Loss at iteration 100 : 1.1877195835113525
Loss at iteration 110 : 0.9014407396316528
Loss at iteration 120 : 0.6933799982070923
Loss at iteration 130 : 0.39657557010650635
Loss at iteration 140 : 0.707856297492981
Loss at iteration 150 : 0.4551445245742798
Loss at iteration 160 : 0.4587801396846771
Loss at iteration 170 : 1.5810915231704712
Loss at iteration 180 : 0.6049085855484009
Loss at iteration 190 : 0.7140998840332031
Loss at iteration 200 : 0.6130244731903076
Loss at iteration 210 : 0.451404869556427
Loss at iteration 220 : 1.6068463325500488
Loss at iteration 230 : 1.2669832706451416
Loss at iteration 240 : 0.328972727060318
Loss at iteration 250 : 1.1017274856567383
Loss at iteration 260 : 0.8277110457420349

你可能感兴趣的:(论文精读)