目录
背景
方法
总结
众所周知,目前深度学习去雾方法速度方面一直是个痛点,难以达到效果和速度兼容。基于此,本论文设计了快速的去雾网络,网络基于双边网格,能够通过提取一种双边网格的数据结构对原始输入图像进行变换和增强,不仅能够恢复出不错的效果,而且在速度上具有一定的优势。
上面是该去雾网络的主要架构图,核心采用双边网格,双边网格能够关注色彩突变的物体边界,并能够很好的关注于高频信息。可以看到非常地简洁,实际上,为了简洁而简洁可能有时候会对读者造成一定的误导。建议阅读源代码。
在该模型中的特征提取部分代码里使用的是Unet网络,图中则画的很简单,另外在RGB图像表示部分也不够清楚,实际上在图像输入之后也会有一个Unet进行特征提取,然后对提取后的特征分别进行三次卷积,以生成三个通道共九个通道的R, G, B channel。论文在速度上堪比AODNet,这点有待商榷。据论文所述,在实时性方面达到较好结果。
废话不多说,下面详细剖析一下论文做了哪些工作,论文的模型架构。
首先看结构图中的如下部分:
此部分为了生成双边网格,reduced resolution为输入图像的下采样结果,论文指出,通过低分辨率的图像进行双边网格生成影响不大,而且能够提升速度。Feature extraction 是由一个简单的Unet网络结构构成,该Unet由四次下采样和四次上采样构成,生成Features,代码如下
class UNet(nn.Module):
def __init__(self, n_channels, bilinear=True):
super(UNet, self).__init__()
self.n_channels = n_channels
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
factor = 2 if bilinear else 1
self.down4 = Down(512, 1024 // factor)
self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128 // factor, bilinear)
self.up4 = Up(128, 64, bilinear)
self.pre = nn.Conv2d(64, 3, 3, 1, 1)
self.re = nn.Sigmoid()
def forward(self, xs):
x1 = self.inc(xs)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.re(self.pre(x))
由Features到Affine bilateral grid经历了slice操作
class Slice(nn.Module):
def __init__(self):
super(Slice, self).__init__()
def forward(self, bilateral_grid, guidemap): #bilateral_grid:(64,256)|(-1, 12, 16, 16, 16) , guidemap:经卷积预处理后的单一颜色通道(256,256)
device = bilateral_grid.get_device()
N, _, H, W = guidemap.shape
hg, wg = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)]) # [0,511] HxW
if device >= 0:
hg = hg.to(device)
wg = wg.to(device)
hg = hg.float().repeat(N, 1, 1).unsqueeze(3) / (H-1) # norm to [0,1] NxHxWx1
wg = wg.float().repeat(N, 1, 1).unsqueeze(3) / (W-1) # norm to [0,1] NxHxWx1
hg, wg = hg*2-1, wg*2-1 ###########
guidemap = guidemap.permute(0, 2, 3, 1).contiguous()
guidemap_guide = torch.cat([wg, hg, guidemap], dim=3).unsqueeze(1) # Nx1xHxWx3
coeff = F.grid_sample(bilateral_grid, guidemap_guide, align_corners=True) #
return coeff.squeeze(2)
下面是在原分辨率进行提取双边网格指导的特征图
原始输入图像经过Unet(图中未显示出),生成特征图(3个通道),图有误,分别送入三个卷积,生成3个三通道的特征图。
下面是将生成的Affine bilateral grid 指导调整 生成的3个三通道的特征图,生成3个3通道的结果特征图
代码如下
class ApplyCoeffs(nn.Module):
def __init__(self):
super(ApplyCoeffs, self).__init__()
self.degree = 3
def forward(self, coeff, full_res_input):
R = torch.sum(full_res_input * coeff[:, 0:3, :, :], dim=1, keepdim=True) + coeff[:, 3:4, :, :]
G = torch.sum(full_res_input * coeff[:, 4:7, :, :], dim=1, keepdim=True) + coeff[:, 7:8, :, :]
B = torch.sum(full_res_input * coeff[:, 8:11, :, :], dim=1, keepdim=True) + coeff[:, 11:12, :, :]
result = torch.cat([R, G, B], dim=1)
return result
最后是进行concat操作生成9通道特征图送入一个特征融合层融合为3通道,再经过简单地卷积调整,生成中间结果output,然后与输入进行相乘跳连,生成最终结果。
代码如下
output = torch.cat((output_r, output_g, output_b), dim=1) #生成九通道特征图
output = self.fusion(output) #生成三通道特征图
output = self.p(self.x_r_fusion(output) * x - output + 1) #x_r_fusion为简单的两层卷积
文章实质我想主要还是在前人的基础上进行修改用于去雾任务,文章思路不错,可以借鉴
更新代码下载地址:
链接:https://pan.baidu.com/s/1GuyZoRsRrO6aVofRjLw0DQ?pwd=ong4
提取码:ong4
本文为双边网格系列文章,系列目录如下:
双边网格学习、Bilateral Learning_Alocus_的博客-CSDN博客目录背景方法结论挖坑,双边网格学习。双边网格具有很多优良的特性,在图像恢复等方面还具有很多的价值可以挖掘,因此本系列我会把相关论文和代码以我的理解写成博客,留做记录,代码运行过程中如果有问题或者发现我也会写到博客里。背景双边网格学习有些相似于双边滤波,双边滤波(Bilateral Filter)是非线性滤波中的一种,结合图像的空间邻近度与像素值相似度。 在滤除噪声、平滑图像的同时,又保存边缘。一个负责计算空间邻近度 的权值,也就是常用的高斯 滤波器原理。而另一个负责 计算像素值相似度https://blog.csdn.net/Crystal_remember/article/details/123333844