[MICCAI2019] Attention Guided Network for Retinal Image Segmentation

作者信息:
Shihao Zhang,华南理工


文章在分割网络结构中引入Guided Image Filteringattention blockGuided Image Filtering用来引导图像特征,传递结构信息(边缘信息);attention block用来去除噪声与无关背景的影响。

Guided Image Filtering

Kaiming He ECCV2010提出。edge-aware滤波算子。介绍Guided Image Filtering之前,先简单介绍下双边滤波。

双边滤波

参考资料
高斯滤波引入了空间信息,但缺点是,滤波后虽然减少了高频噪声,但相应地原有的边缘也会更加模糊.公式如下,滤波窗内距离中心点越远的像素,对滤波后的结果影响(权值)越小:
[MICCAI2019] Attention Guided Network for Retinal Image Segmentation_第1张图片
高斯滤波结果如下:
[MICCAI2019] Attention Guided Network for Retinal Image Segmentation_第2张图片
双边滤波在此高斯滤波的基础上还引入了像素信息,滤波窗内与中心点像素差异越大的点对最终结果的影响越小,这使得边缘信息不会被背景信息所掩盖。
[MICCAI2019] Attention Guided Network for Retinal Image Segmentation_第3张图片
双边滤波结果如下:
[MICCAI2019] Attention Guided Network for Retinal Image Segmentation_第4张图片

Guided Image Filtering

参考论文:Guided Image Filtering

Guided Image Filtering相比起双边滤波,对边缘信息保留得更好。若有input 图像 p,guided 图像 I,滤波窗w,滤波后图像q,Guided Image Filtering定义滤波窗内,q与I为线性关系,则有如下,a,b为线性系数
在这里插入图片描述
约束为q与p要尽可能相似,则定义如下损失函数。其中第二项为正则项。
在这里插入图片描述
使E最小,则得到a,b的最优解。其中 μ k \mu_k μk, σ k 2 \sigma_k^2 σk2为I在窗口内的均值和方差, p k ‾ \overline{p_k} pk为p在窗口内的均值
在这里插入图片描述
得到a,b后我们就可以由引导图像I通过线性计算得到p的引导滤波的滤波图。
此外,还要考虑到滤波窗有重叠,重叠部分输出取各个窗口输出的均值,则有
在这里插入图片描述
经推导,窗函数等价为:
在这里插入图片描述
实际算法流程如下:
[MICCAI2019] Attention Guided Network for Retinal Image Segmentation_第5张图片
滤波示例:
[MICCAI2019] Attention Guided Network for Retinal Image Segmentation_第6张图片

Side Window Filtering

跑下题,最新2019CVPR oral的工作,改进点是把待滤波点设置在滤波窗的边上,而不是在窗中心。值得了解一下。
[MICCAI2019] Attention Guided Network for Retinal Image Segmentation_第7张图片
可以在已有滤波算法上直接改进。下图bcdef依次是均值,高斯,中值,双边,guided滤波以及对应的改进效果。

Method

baseline网络M-Net是该作者2018年的TMI中提出的网络,本文改进为加入了下图中的AG模块
[MICCAI2019] Attention Guided Network for Retinal Image Segmentation_第8张图片
具体而言,AG引入了attention blockguided image filtering,内部结构如下:
[MICCAI2019] Attention Guided Network for Retinal Image Segmentation_第9张图片
attention block 结构如下:
[MICCAI2019] Attention Guided Network for Retinal Image Segmentation_第10张图片
guided image filtering 部分:
论文的介绍如下,这部分对原始GUI算法改进点就是优化问题中考虑了attention map T,因此a,b的闭式解与原算法稍有不同。
在网络中,I对应AG模块接受的浅层特征,O对应AG模块接受的深层特征

[MICCAI2019] Attention Guided Network for Retinal Image Segmentation_第11张图片

Results

眼部血管分割以及视杯视盘分割结果如下:
[MICCAI2019] Attention Guided Network for Retinal Image Segmentation_第12张图片
[MICCAI2019] Attention Guided Network for Retinal Image Segmentation_第13张图片

代码

代码地址

1、attention block

class GridAttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(GridAttentionBlock, self).__init__()

        self.inter_channels = in_channels
        self.in_channels = in_channels
        self.gating_channels = in_channels

        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1)

        self.phi = nn.Conv2d(in_channels=self.gating_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0, bias=True)
        self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, g): #x,g 分别对应论文中I_l 和 O
        input_size = x.size()
        batch_size = input_size[0]
        assert batch_size == g.size(0)

        theta_x = self.theta(x)
        theta_x_size = theta_x.size()

        phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode='bilinear')
        f = F.relu(theta_x + phi_g, inplace=True)

        sigm_psi_f = F.sigmoid(self.psi(f))

        return sigm_psi_f

2、AG模块

class FastGuidedFilter_attention(nn.Module):
    def __init__(self, r, eps=1e-8):
        super(FastGuidedFilter_attention, self).__init__()

        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)
        self.epss = 1e-12

    def forward(self, lr_x, lr_y, hr_x, l_a):
     # 输入分别对应论文中 I_l, O, I, T 
        n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
        n_lry, c_lry, h_lry, w_lry = lr_y.size()
        n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()

        lr_x = lr_x.double()
        lr_y = lr_y.double()
        hr_x = hr_x.double()
        l_a = l_a.double()

        assert n_lrx == n_lry and n_lry == n_hrx
        assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
        assert h_lrx == h_lry and w_lrx == w_lry
        assert h_lrx > 2*self.r+1 and w_lrx > 2*self.r+1

        ## N
        N = self.boxfilter(Variable(lr_x.data.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0)))

        # l_a = torch.abs(l_a)
        l_a = torch.abs(l_a) + self.epss
        t_all = torch.sum(l_a)
        l_t = l_a / t_all  ## norm

        ## mean_attention
        mean_a = self.boxfilter(l_a) / N
        ## mean_a^2xy
        mean_a2xy = self.boxfilter(l_a * l_a * lr_x * lr_y) / N
        ## mean_tax
        mean_tax = self.boxfilter(l_t * l_a * lr_x) / N
        ## mean_ay
        mean_ay = self.boxfilter(l_a * lr_y) / N
        ## mean_a^2x^2
        mean_a2x2 = self.boxfilter(l_a * l_a * lr_x * lr_x) / N
        ## mean_ax
        mean_ax = self.boxfilter(l_a * lr_x) / N

        ## A
        temp = torch.abs(mean_a2x2 - N * mean_tax * mean_ax)
        A = (mean_a2xy - N * mean_tax * mean_ay) / (temp + self.eps)
        ## b
        b = (mean_ay - A * mean_ax) / (mean_a)

        # --------------------------------
        # Mean
        # --------------------------------
        A = self.boxfilter(A) / N
        b = self.boxfilter(b) / N


        ## mean_A; mean_b
        mean_A = F.upsample(A, (h_hrx, w_hrx), mode='bilinear')
        mean_b = F.upsample(b, (h_hrx, w_hrx), mode='bilinear')

        return (mean_A*hr_x+mean_b).float()

我的笔记

1、文章引入了传统cv方法:Guided Image Filtering。看来不能局限自己的一亩三分地,还是得广泛涉猎整个CV领域。
1、希望作者对使用Guided Image Filtering之后出来的特征进行可视化比较与分析。

你可能感兴趣的:(MICCAI2019)