Paper: https://arxiv.org/abs/2001.06826
Highlights: no paired training data is required; only the parameters of a per-pixel higher-order enhancement curve are learned; the loss function accounts for spatial consistency, exposure, color constancy, and illumination smoothness.
In this paper, we present a novel deep-learning method, Zero-Reference Deep Curve Estimation (Zero-DCE), for low-light image enhancement. It can handle a wide range of lighting conditions, including non-uniform and weak illumination. Instead of performing image-to-image mapping, we reformulate the task as an image-specific curve-estimation problem. Specifically, the proposed method takes a low-light image as input and produces high-order curves as its output. These curves are then used for pixel-wise adjustment over the dynamic range of the input to obtain an enhanced image. The curve estimation is carefully formulated so that it keeps the enhanced image within its valid range and preserves the contrast of neighboring pixels. Importantly, it is differentiable, so the adjustable curve parameters can be learned by a deep convolutional neural network. The proposed network is lightweight, and it can be applied iteratively to approximate higher-order curves for more robust and precise dynamic-range adjustment.
A unique advantage of our deep-learning-based method is that it is zero-reference: unlike existing CNN-based and GAN-based methods, it does not require any paired or even unpaired data during training. This is achieved through a set of specially designed non-reference loss functions, including a spatial consistency loss, an exposure control loss, a color constancy loss, and an illumination smoothness loss, all of which take the multiple factors of light enhancement into account. We find that even with zero-reference training, Zero-DCE remains competitive with methods that require paired or unpaired data. Figure 1 shows an example of enhancing a low-light image with non-uniform illumination. Compared with state-of-the-art methods, Zero-DCE brightens the image while preserving the inherent colors and details. In contrast, the CNN-based method [28] and the GAN-based EnlightenGAN yield under-enhancement (the face) and over-enhancement (the cabinet), respectively.
Figure caption: visual comparison on a typical low-light image. The proposed Zero-DCE achieves pleasing brightness, color, contrast, and naturalness, whereas existing methods either fail to handle the extreme back-lighting or introduce color artifacts. Unlike other deep-learning-based methods, ours is trained without any reference images.
Our contributions are summarized as follows.
1) We propose the first low-light enhancement network that is independent of paired and unpaired training data, thus avoiding the risk of overfitting. As a result, the method generalizes well to various lighting conditions.
2) We design an image-specific curve that is able to approximate pixel-wise and higher-order curves by iteratively applying itself. Such an image-specific curve can efficiently perform mapping within a wide dynamic range.
3) We show the potential of training a deep image-enhancement model without any reference images, through task-specific non-reference loss functions that indirectly evaluate enhancement quality.
Our Zero-DCE method surpasses state-of-the-art performance both qualitatively and quantitatively. More importantly, it can benefit high-level vision tasks such as face detection without imposing a heavy computational burden: it processes images in real time (about 500 FPS for 640×480×3 images on a GPU) and takes only 30 minutes to train.
Key questions:
1. What form should the enhancement equation take?
The light-enhancement mapping is the quadratic curve LE(I(x); α) = I(x) + α·I(x)·(1 − I(x)), where I(x) is the normalized pixel value and α ∈ [−1, 1] is the parameter to be learned.
The equation must satisfy three conditions: the output stays in [0, 1], avoiding overflow during computation; it is simple and differentiable; and it is monotonic, so the differences (contrast) between neighboring pixels are preserved. The learned α not only enhances the image but also controls exposure.
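As a quick illustration (a minimal sketch, not from the paper's code), the following checks numerically that the quadratic curve stays within [0, 1] and remains monotonic for α ∈ [−1, 1]:
import numpy as np

def le_curve(x, alpha):
    # quadratic light-enhancement curve: LE(x; alpha) = x + alpha * x * (1 - x)
    # (the repo code further below writes the same curve as x + r*(x**2 - x), i.e. r = -alpha)
    return x + alpha * x * (1.0 - x)

x = np.linspace(0.0, 1.0, 11)  # normalized pixel values
for alpha in (-1.0, -0.5, 0.0, 0.5, 1.0):
    y = le_curve(x, alpha)
    assert y.min() >= 0.0 and y.max() <= 1.0  # output stays in [0, 1]
    assert np.all(np.diff(y) >= 0.0)          # monotonic: neighbor ordering is preserved
    print("alpha =", alpha, " LE(0.2) =", round(float(le_curve(0.2, alpha)), 3))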
2. How is the curve iterated?
The curve is applied iteratively: LE_n(x) = LE_{n-1}(x) + A_n(x) · LE_{n-1}(x) · (1 − LE_{n-1}(x)), where A_n is the per-pixel parameter map for the n-th iteration.
The principle resembles a recursive call: the output of the previous step becomes the input of the next, so the input image is progressively adjusted while training drives the curve parameters toward the minimum loss. The number of iterations is fixed rather than adaptive; the paper determines it experimentally to be 8.
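A small sketch (NumPy, purely illustrative) of applying the curve iteratively with one per-pixel parameter map per step, mirroring the eight unrolled steps in the network's forward pass further below:
import numpy as np

def iterative_le(img, alpha_maps):
    # img: H x W x 3 array in [0, 1]; alpha_maps: list of per-pixel maps in [-1, 1]
    x = img
    for a in alpha_maps:           # n = 8 iterations in the paper / repo
        x = x + a * x * (1.0 - x)  # each step re-applies the quadratic curve to the previous output
    return x

rng = np.random.default_rng(0)
img = rng.uniform(0.0, 0.3, size=(4, 4, 3))               # a dark toy "image"
alpha_maps = [np.full((4, 4, 3), 0.8) for _ in range(8)]  # constant maps, just for illustration
out = iterative_le(img, alpha_maps)
print(round(img.mean(), 3), "->", round(out.mean(), 3))   # the mean intensity increases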
3. How is this combined with a CNN?
The network architecture is as follows: DCE-Net is a plain seven-layer CNN with symmetric skip concatenations. Its output is a set of 24 curve-parameter maps (8 iterations × 3 color channels), which are applied channel by channel to the input image through the iterative curve above, with each iteration's output feeding the next. The loss is not computed against any reference image; it is the sum of the non-reference terms below, which constrain the enhanced image relative to the original input (spatial consistency) and relative to the desired exposure, color, and smoothness properties.
Contribution of each loss (ablation):
Removing Lspa shows its importance in preserving the differences between neighboring regions of the input and the enhanced image.
Removing Lexp fails to recover the low-light regions.
Discarding Lcol leads to severe color casts, since this variant ignores the relationship among the three channels when the curve mapping is applied.
Removing LtvA hampers the correlation between neighboring regions, leading to obvious artifacts.
The loss functions are implemented as follows (Myloss.py):
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision.models.vgg import vgg16
import pytorch_colors as colors
import numpy as np
class L_color(nn.Module):  # color constancy loss: penalizes differences between the R, G, B channel means
def __init__(self):
super(L_color, self).__init__()
def forward(self, x ):
b,c,h,w = x.shape
mean_rgb = torch.mean(x,[2,3],keepdim=True)
mr,mg, mb = torch.split(mean_rgb, 1, dim=1)
Drg = torch.pow(mr-mg,2)
Drb = torch.pow(mr-mb,2)
Dgb = torch.pow(mb-mg,2)
k = torch.pow(torch.pow(Drg,2) + torch.pow(Drb,2) + torch.pow(Dgb,2),0.5)
return k
class L_spa(nn.Module):  # spatial consistency loss: preserves differences between neighboring regions of the input and the enhanced image
def __init__(self):
super(L_spa, self).__init__()
        # fixed 3x3 difference kernels (left / right / up / down neighbors)
kernel_left = torch.FloatTensor( [[0,0,0],[-1,1,0],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
kernel_right = torch.FloatTensor( [[0,0,0],[0,1,-1],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
kernel_up = torch.FloatTensor( [[0,-1,0],[0,1, 0 ],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
kernel_down = torch.FloatTensor( [[0,0,0],[0,1, 0],[0,-1,0]]).cuda().unsqueeze(0).unsqueeze(0)
self.weight_left = nn.Parameter(data=kernel_left, requires_grad=False)
self.weight_right = nn.Parameter(data=kernel_right, requires_grad=False)
self.weight_up = nn.Parameter(data=kernel_up, requires_grad=False)
self.weight_down = nn.Parameter(data=kernel_down, requires_grad=False)
self.pool = nn.AvgPool2d(4)
def forward(self, org , enhance ):
b,c,h,w = org.shape
org_mean = torch.mean(org,1,keepdim=True)
enhance_mean = torch.mean(enhance,1,keepdim=True)
org_pool = self.pool(org_mean)
enhance_pool = self.pool(enhance_mean)
weight_diff =torch.max(torch.FloatTensor([1]).cuda() + 10000*torch.min(org_pool - torch.FloatTensor([0.3]).cuda(),torch.FloatTensor([0]).cuda()),torch.FloatTensor([0.5]).cuda())
E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5]).cuda()) ,enhance_pool-org_pool)
D_org_letf = F.conv2d(org_pool , self.weight_left, padding=1)
D_org_right = F.conv2d(org_pool , self.weight_right, padding=1)
D_org_up = F.conv2d(org_pool , self.weight_up, padding=1)
D_org_down = F.conv2d(org_pool , self.weight_down, padding=1)
D_enhance_letf = F.conv2d(enhance_pool , self.weight_left, padding=1)
D_enhance_right = F.conv2d(enhance_pool , self.weight_right, padding=1)
D_enhance_up = F.conv2d(enhance_pool , self.weight_up, padding=1)
D_enhance_down = F.conv2d(enhance_pool , self.weight_down, padding=1)
D_left = torch.pow(D_org_letf - D_enhance_letf,2)
D_right = torch.pow(D_org_right - D_enhance_right,2)
D_up = torch.pow(D_org_up - D_enhance_up,2)
D_down = torch.pow(D_org_down - D_enhance_down,2)
E = (D_left + D_right + D_up +D_down)
# E = 25*(D_left + D_right + D_up +D_down)
return E
class L_exp(nn.Module):  # exposure control loss: drives the mean intensity of local patches toward mean_val
def __init__(self,patch_size,mean_val):
super(L_exp, self).__init__()
# print(1)
self.pool = nn.AvgPool2d(patch_size)
self.mean_val = mean_val
def forward(self, x ):
b,c,h,w = x.shape
x = torch.mean(x,1,keepdim=True)
mean = self.pool(x)
d = torch.mean(torch.pow(mean- torch.FloatTensor([self.mean_val] ).cuda(),2))
return d
class L_TV(nn.Module):  # illumination smoothness (total variation) loss, applied to the curve-parameter maps A
def __init__(self,TVLoss_weight=1):
super(L_TV,self).__init__()
self.TVLoss_weight = TVLoss_weight
def forward(self,x):
batch_size = x.size()[0]
h_x = x.size()[2]
w_x = x.size()[3]
count_h = (x.size()[2]-1) * x.size()[3]
count_w = x.size()[2] * (x.size()[3] - 1)
h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
class Sa_Loss(nn.Module):  # saturation loss (defined here but not used in the training script below)
def __init__(self):
super(Sa_Loss, self).__init__()
# print(1)
def forward(self, x ):
# self.grad = np.ones(x.shape,dtype=np.float32)
b,c,h,w = x.shape
# x_de = x.cpu().detach().numpy()
r,g,b = torch.split(x , 1, dim=1)
mean_rgb = torch.mean(x,[2,3],keepdim=True)
mr,mg, mb = torch.split(mean_rgb, 1, dim=1)
Dr = r-mr
Dg = g-mg
Db = b-mb
k =torch.pow( torch.pow(Dr,2) + torch.pow(Db,2) + torch.pow(Dg,2),0.5)
# print(k)
k = torch.mean(k)
return k
class perception_loss(nn.Module):  # VGG-16 perceptual feature extractor (defined here but not used in the training script below)
def __init__(self):
super(perception_loss, self).__init__()
features = vgg16(pretrained=True).features
self.to_relu_1_2 = nn.Sequential()
self.to_relu_2_2 = nn.Sequential()
self.to_relu_3_3 = nn.Sequential()
self.to_relu_4_3 = nn.Sequential()
for x in range(4):
self.to_relu_1_2.add_module(str(x), features[x])
for x in range(4, 9):
self.to_relu_2_2.add_module(str(x), features[x])
for x in range(9, 16):
self.to_relu_3_3.add_module(str(x), features[x])
for x in range(16, 23):
self.to_relu_4_3.add_module(str(x), features[x])
# don't need the gradients, just want the features
for param in self.parameters():
param.requires_grad = False
def forward(self, x):
h = self.to_relu_1_2(x)
h_relu_1_2 = h
h = self.to_relu_2_2(h)
h_relu_2_2 = h
h = self.to_relu_3_3(h)
h_relu_3_3 = h
h = self.to_relu_4_3(h)
h_relu_4_3 = h
# out = (h_relu_1_2, h_relu_2_2, h_relu_3_3, h_relu_4_3)
return h_relu_4_3
We evaluated the effect of Zero-DCE's parameters, including the depth and width of the DCE-Net and the number of iterations. (In the naming Zero-DCE d-w-n, d is the number of convolutional layers, w the number of feature maps per layer, and n the number of curve iterations.) In Fig. 5(b), with only three convolutional layers, Zero-DCE 3-32-8 can already produce satisfactory results, indicating the effectiveness of zero-reference learning. Zero-DCE 7-32-8 and Zero-DCE 7-32-16 produce the most visually pleasing results, with natural exposure and proper contrast. Reducing the number of iterations to 1 causes an obvious performance drop for Zero-DCE 7-32-1, as shown in Fig. 5(d), because a curve applied only once has limited adjustment capability; this shows that the method needs higher-order curves. We choose Zero-DCE 7-32-8 as the final model based on its good trade-off between efficiency and restoration performance.
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.tanh in forward()
import math
import pytorch_colors as colors
import numpy as np
class enhance_net_nopool(nn.Module):  # DCE-Net: the curve-estimation network
def __init__(self):
super(enhance_net_nopool, self).__init__()
self.relu = nn.ReLU(inplace=True)
        number_f = 32  # 32 feature maps per convolutional layer
self.e_conv1 = nn.Conv2d(3,number_f,3,1,1,bias=True)
self.e_conv2 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
self.e_conv3 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
self.e_conv4 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
self.e_conv5 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
self.e_conv6 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
        self.e_conv7 = nn.Conv2d(number_f*2,24,3,1,1,bias=True)  # 7th conv layer: 24 output channels = 8 iterations x 3 curve-parameter maps
        self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)  # max pooling (defined but unused in forward)
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)  # bilinear upsampling (defined but unused in forward)
def forward(self, x):
x1 = self.relu(self.e_conv1(x))
# p1 = self.maxpool(x1)
x2 = self.relu(self.e_conv2(x1))
# p2 = self.maxpool(x2)
x3 = self.relu(self.e_conv3(x2))
# p3 = self.maxpool(x3)
x4 = self.relu(self.e_conv4(x3))
        x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))  # concatenate x3 and x4 along the channel dimension (skip connection)
# x5 = self.upsample(x5)
x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))
x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1)))
        r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 3, dim=1)  # split the 24 channels into 8 groups of 3 curve-parameter maps
x = x + r1*(torch.pow(x,2)-x)
x = x + r2*(torch.pow(x,2)-x)
x = x + r3*(torch.pow(x,2)-x)
enhance_image_1 = x + r4*(torch.pow(x,2)-x)
x = enhance_image_1 + r5*(torch.pow(enhance_image_1,2)-enhance_image_1)
x = x + r6*(torch.pow(x,2)-x)
x = x + r7*(torch.pow(x,2)-x)
enhance_image = x + r8*(torch.pow(x,2)-x)
r = torch.cat([r1,r2,r3,r4,r5,r6,r7,r8],1)
return enhance_image_1,enhance_image,r
import os
import sys
import torch
import torch.utils.data as data
import numpy as np
from PIL import Image
import glob
import random
import cv2
random.seed(1143)
def populate_train_list(lowlight_images_path):  # build the list of training (low-light) image paths
    image_list_lowlight = glob.glob(lowlight_images_path + "*.jpg")  # collect all low-light .jpg images
    train_list = image_list_lowlight
    random.shuffle(train_list)  # shuffle the list
    return train_list  # return the shuffled list
class lowlight_loader(data.Dataset):  # dataset of low-light training images
    def __init__(self, lowlight_images_path):
        self.train_list = populate_train_list(lowlight_images_path)  # paths of the training images
        self.size = 256
        self.data_list = self.train_list
        print("Total training examples:", len(self.train_list))  # number of training images
    def __getitem__(self, index):
        data_lowlight_path = self.data_list[index]  # look up the image path
        data_lowlight = Image.open(data_lowlight_path)  # load the image
        data_lowlight = data_lowlight.resize((self.size,self.size), Image.ANTIALIAS)  # resize to 256x256 (Image.ANTIALIAS = high-quality resampling; Image.LANCZOS on newer Pillow)
        data_lowlight = (np.asarray(data_lowlight)/255.0)  # convert to an array normalized to [0, 1]
        data_lowlight = torch.from_numpy(data_lowlight).float()  # convert to a float tensor
        return data_lowlight.permute(2,0,1)  # HWC -> CHW
def __len__(self):
        return len(self.data_list)  # number of images in the dataset
if __name__ == "__main__":
train_list = populate_train_list('Zero-DCE-master/Zero-DCE_code/data/train_data/')
print(train_list)
import torch
import torch.nn as nn
import torchvision
import torch.backends.cudnn as cudnn
import torch.optim
import os
import sys
import argparse
import time
import dataloader
import model
import Myloss
import numpy as np
from torchvision import transforms
def weights_init(m):  # initialize convolution and batch-norm weights
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def train(config):  # training loop
    os.environ['CUDA_VISIBLE_DEVICES']='0'
    DCE_net = model.enhance_net_nopool().cuda()  # the deep curve estimation network
    DCE_net.apply(weights_init)
    if config.load_pretrain == True:
        DCE_net.load_state_dict(torch.load(config.pretrain_dir))  # only loaded when --load_pretrain is True (default: False)
    train_dataset = dataloader.lowlight_loader(config.lowlight_images_path)  # low-light training set (images resized to 256x256)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True)
L_color = Myloss.L_color()
L_spa = Myloss.L_spa()
    L_exp = Myloss.L_exp(16,0.6)  # 16x16 patches, target exposure level 0.6
L_TV = Myloss.L_TV()
optimizer = torch.optim.Adam(DCE_net.parameters(), lr=config.lr, weight_decay=config.weight_decay)
DCE_net.train()
for epoch in range(config.num_epochs):
for iteration, img_lowlight in enumerate(train_loader):
img_lowlight = img_lowlight.cuda()
enhanced_image_1,enhanced_image,A = DCE_net(img_lowlight)
            Loss_TV = 200*L_TV(A)  # illumination smoothness on the curve-parameter maps, weight 200
            loss_spa = torch.mean(L_spa(enhanced_image, img_lowlight))  # spatial consistency
            loss_col = 5*torch.mean(L_color(enhanced_image))  # color constancy, weight 5
            loss_exp = 10*torch.mean(L_exp(enhanced_image))  # exposure control, weight 10
            loss = Loss_TV + loss_spa + loss_col + loss_exp  # total non-reference loss
#
optimizer.zero_grad()
loss.backward()
            torch.nn.utils.clip_grad_norm_(DCE_net.parameters(),config.grad_clip_norm)  # gradient clipping
optimizer.step()
if ((iteration+1) % config.display_iter) == 0:
print("Loss at iteration", iteration+1, ":", loss.item())
if ((iteration+1) % config.snapshot_iter) == 0:
                torch.save(DCE_net.state_dict(), config.snapshots_folder + "Epoch" + str(epoch) + '.pth')  # save the model weights
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Input Parameters
parser.add_argument('--lowlight_images_path', type=str, default="Zero-DCE-master/Zero-DCE_code/data/train_data/")
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--weight_decay', type=float, default=0.0001)
parser.add_argument('--grad_clip_norm', type=float, default=0.1)
parser.add_argument('--num_epochs', type=int, default=200)
parser.add_argument('--train_batch_size', type=int, default=1)
parser.add_argument('--val_batch_size', type=int, default=4)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--display_iter', type=int, default=10)
parser.add_argument('--snapshot_iter', type=int, default=10)
parser.add_argument('--snapshots_folder', type=str, default="snapshots/")
parser.add_argument('--load_pretrain', type=bool, default= False)
parser.add_argument('--pretrain_dir', type=str, default= "snapshots/Epoch99.pth")
config = parser.parse_args()
#
# print(config.lowlight_images_path,config.lr)
if not os.path.exists(config.snapshots_folder):
os.mkdir(config.snapshots_folder)
train(config)
Training log:
Loss at iteration 10 : 1.191946029663086
Loss at iteration 20 : 0.8900860548019409
Loss at iteration 30 : 1.0086252689361572
Loss at iteration 40 : 0.6189944744110107
Loss at iteration 50 : 0.7224398255348206
Loss at iteration 60 : 0.965330958366394
Loss at iteration 70 : 0.44430315494537354
Loss at iteration 80 : 0.9305168986320496
Loss at iteration 90 : 0.4105320870876312
Loss at iteration 100 : 1.1877195835113525
Loss at iteration 110 : 0.9014407396316528
Loss at iteration 120 : 0.6933799982070923
Loss at iteration 130 : 0.39657557010650635
Loss at iteration 140 : 0.707856297492981
Loss at iteration 150 : 0.4551445245742798
Loss at iteration 160 : 0.4587801396846771
Loss at iteration 170 : 1.5810915231704712
Loss at iteration 180 : 0.6049085855484009
Loss at iteration 190 : 0.7140998840332031
Loss at iteration 200 : 0.6130244731903076
Loss at iteration 210 : 0.451404869556427
Loss at iteration 220 : 1.6068463325500488
Loss at iteration 230 : 1.2669832706451416
Loss at iteration 240 : 0.328972727060318
Loss at iteration 250 : 1.1017274856567383
Loss at iteration 260 : 0.8277110457420349
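After training, inference can be run with the network defined above. The following is a minimal sketch (assumptions: the model code is importable as model, a trained checkpoint exists at snapshots/Epoch99.pth, and test.jpg is an arbitrary low-light image; these file names are illustrative, not from this post):
import torch
import numpy as np
from PIL import Image
import model  # the enhance_net_nopool defined above

DCE_net = model.enhance_net_nopool().cuda()
DCE_net.load_state_dict(torch.load("snapshots/Epoch99.pth"))  # assumed checkpoint path
DCE_net.eval()

img = Image.open("test.jpg")                       # illustrative input path
x = torch.from_numpy(np.asarray(img) / 255.0).float()
x = x.permute(2, 0, 1).unsqueeze(0).cuda()         # HWC -> 1 x 3 x H x W

with torch.no_grad():
    _, enhanced, _ = DCE_net(x)                    # the second output is the final enhanced image

out = (enhanced[0].permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype("uint8")
Image.fromarray(out).save("test_enhanced.jpg")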