作者信息:
Shihao Zhang,华南理工
文章在分割网络结构中引入Guided Image Filtering
和 attention block
。Guided Image Filtering
用来引导图像特征,传递结构信息(边缘信息);attention block
用来去除噪声与无关背景的影响。
Kaiming He ECCV2010提出。edge-aware滤波算子。介绍Guided Image Filtering之前,先简单介绍下双边滤波。
参考资料
高斯滤波引入了空间信息,但缺点是,滤波后虽然减少了高频噪声,但相应地原有的边缘也会更加模糊.公式如下,滤波窗内距离中心点越远的像素,对滤波后的结果影响(权值)越小:
高斯滤波结果如下:
双边滤波在此高斯滤波的基础上还引入了像素信息,滤波窗内与中心点像素差异越大的点对最终结果的影响越小,这使得边缘信息不会被背景信息所掩盖。
双边滤波结果如下:
参考论文:Guided Image Filtering
Guided Image Filtering相比起双边滤波,对边缘信息保留得更好。若有input 图像 p,guided 图像 I,滤波窗w,滤波后图像q,Guided Image Filtering定义滤波窗内,q与I为线性关系,则有如下,a,b为线性系数
约束为q与p要尽可能相似,则定义如下损失函数。其中第二项为正则项。
使E最小,则得到a,b的最优解。其中 μ k \mu_k μk, σ k 2 \sigma_k^2 σk2为I在窗口内的均值和方差, p k ‾ \overline{p_k} pk为p在窗口内的均值
得到a,b后我们就可以由引导图像I
通过线性计算得到p
的引导滤波的滤波图。
此外,还要考虑到滤波窗有重叠,重叠部分输出取各个窗口输出的均值,则有
经推导,窗函数等价为:
实际算法流程如下:
滤波示例:
跑下题,最新2019CVPR oral的工作,改进点是把待滤波点设置在滤波窗的边上,而不是在窗中心。值得了解一下。
可以在已有滤波算法上直接改进。下图bcdef依次是均值,高斯,中值,双边,guided滤波以及对应的改进效果。
baseline网络M-Net是该作者2018年的TMI中提出的网络,本文改进为加入了下图中的AG模块
,
具体而言,AG引入了attention block
与 guided image filtering
,内部结构如下:
attention block
结构如下:
guided image filtering
部分:
论文的介绍如下,这部分对原始GUI算法改进点就是优化问题中考虑了attention map T,因此a,b的闭式解与原算法稍有不同。
在网络中,I对应AG模块接受的浅层特征,O对应AG模块接受的深层特征
代码地址
1、attention block
class GridAttentionBlock(nn.Module):
def __init__(self, in_channels):
super(GridAttentionBlock, self).__init__()
self.inter_channels = in_channels
self.in_channels = in_channels
self.gating_channels = in_channels
self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1)
self.phi = nn.Conv2d(in_channels=self.gating_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0, bias=True)
self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, g): #x,g 分别对应论文中I_l 和 O
input_size = x.size()
batch_size = input_size[0]
assert batch_size == g.size(0)
theta_x = self.theta(x)
theta_x_size = theta_x.size()
phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode='bilinear')
f = F.relu(theta_x + phi_g, inplace=True)
sigm_psi_f = F.sigmoid(self.psi(f))
return sigm_psi_f
2、AG模块
class FastGuidedFilter_attention(nn.Module):
def __init__(self, r, eps=1e-8):
super(FastGuidedFilter_attention, self).__init__()
self.r = r
self.eps = eps
self.boxfilter = BoxFilter(r)
self.epss = 1e-12
def forward(self, lr_x, lr_y, hr_x, l_a):
# 输入分别对应论文中 I_l, O, I, T
n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
n_lry, c_lry, h_lry, w_lry = lr_y.size()
n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()
lr_x = lr_x.double()
lr_y = lr_y.double()
hr_x = hr_x.double()
l_a = l_a.double()
assert n_lrx == n_lry and n_lry == n_hrx
assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
assert h_lrx == h_lry and w_lrx == w_lry
assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1
## N
N = self.boxfilter(Variable(lr_x.data.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0)))
# l_a = torch.abs(l_a)
l_a = torch.abs(l_a) + self.epss
t_all = torch.sum(l_a)
l_t = l_a / t_all ## norm
## mean_attention
mean_a = self.boxfilter(l_a) / N
## mean_a^2xy
mean_a2xy = self.boxfilter(l_a * l_a * lr_x * lr_y) / N
## mean_tax
mean_tax = self.boxfilter(l_t * l_a * lr_x) / N
## mean_ay
mean_ay = self.boxfilter(l_a * lr_y) / N
## mean_a^2x^2
mean_a2x2 = self.boxfilter(l_a * l_a * lr_x * lr_x) / N
## mean_ax
mean_ax = self.boxfilter(l_a * lr_x) / N
## A
temp = torch.abs(mean_a2x2 - N * mean_tax * mean_ax)
A = (mean_a2xy - N * mean_tax * mean_ay) / (temp + self.eps)
## b
b = (mean_ay - A * mean_ax) / (mean_a)
# --------------------------------
# Mean
# --------------------------------
A = self.boxfilter(A) / N
b = self.boxfilter(b) / N
## mean_A; mean_b
mean_A = F.upsample(A, (h_hrx, w_hrx), mode='bilinear')
mean_b = F.upsample(b, (h_hrx, w_hrx), mode='bilinear')
return (mean_A*hr_x+mean_b).float()
1、文章引入了传统cv方法:Guided Image Filtering。看来不能局限自己的一亩三分地,还是得广泛涉猎整个CV领域。
1、希望作者对使用Guided Image Filtering之后出来的特征进行可视化比较与分析。