EMANet:Expectation-Maximization Attention Networks for Semantic Segmentation论文解读和代码解读

官方项目地址:含论文和代码

来自北大才子 立夏之光的 ICCV Oral ,理论很漂亮。属于Non local方式

idea

由于理论方面涉及了机器学习算法 - EM算法,博主虽然学过EM,但时间久远有些记不起,这篇论文吧博主看了很久,依然没能理解其精髓,但是不影响我会使用它(哈哈)。言归正传。

在语义分割中,越来越多Non local的方法出现了,并且都取得了精度上的进步,说明Non local确实是有用的。但是这些方法都不能避免庞大的计算量,比如DANet,有很大的矩阵相乘。

EMANet的提出正是为解决Non local带来的计算量过于庞大。通过EM,E步学习一组attention maps, M步更新一组基,经过几次迭代之后,用基和maps 重构特征。 基的向量长度可以是个比较小的数值,我们可以理解为通过把原始特征降维,在低维的流形中建模像素之间的联系,这样的话,可以省略很多计算量。然后通过基和attention maps 重构出高维的、带有全局性的信息的特征。用这个特征在去做最后的分割。

network

EMANet:Expectation-Maximization Attention Networks for Semantic Segmentation论文解读和代码解读_第1张图片
本文不关注理论,只关注步骤,因为理论是在太难弄懂了。有关理论,可以去看第一作者的知乎专栏。

  • 经过一个CNN-based backjbone 得到特征X,经过一个 1x1的卷积降维,(因为ResNet最后的卷积输出的是2048通道的,太大了),降至512个通道。假设 X ∈ R N × C , N = H × W X \in R^{N \times C}, N = H \times W XRN×C,N=H×W。 H和W是特征图X的分辨率尺寸。
  • 初始化一个 μ ∈ R K × C \mu \in R^{K \times C} μRK×C作为基, K指的是有K个基。
  • E步: 得到attention maps, 记作Z。 Z = s o f t m a x ( λ X μ T ) ∈ R N × K Z = softmax(\lambda X \mu^T) \in R^{N \times K} Z=softmax(λXμT)RN×K, 即有K个maps, 每一个map的尺寸是H x W (N)
  • M步: 更新基 μ \mu μ,得到的maps Z, 先在第2个维度,即(dim=1,从0开始算)求和,做一个normlize。具体看代码解读部分。
  • 在每次M步之后,为了保证 μ \mu μ的学习是稳定的,选择L2Norm对 μ \mu μ做归一化。
  • E步和M步重复T次,T在论文中为3。
  • 训练中使用moving average更新 μ \mu μ,测试阶段跳过这一步。
  • 用得到的maps Z和基 μ \mu μ重构X,得到 X ~ ∈ R N × C \widetilde{X} \in R^{N\times C} X RN×C
  • 然后把 X ~ \widetilde{X} X reshape到CxHxW。送到接下来的segHead中。

基不是公共的。每一个样本经过迭代都会得到各自的基,因为不同图像的分布不一样。

Attention Maps

那么既然一组低维空间的基和一组maps(都是K个),能够学习到Non local的信息,那么我们自然该看看这些maps长得是什么样子吧。

我在网上找了一张图像,里面的类别都是VOC数据集出现的。
一共有64个maps,下面是一部分。
EMANet:Expectation-Maximization Attention Networks for Semantic Segmentation论文解读和代码解读_第2张图片

从上图中,可以发现模型确实在低维流形中学习到了Non loca的信息,还减小了计算量。而且通过降维学习(低秩)学习的基可以说是没有冗余的(正交基)。

code explain

下面的代码块是EMA(EM attention) 模块的代码。

    def forward(self, x):
        idn = x
        # The first 1x1 conv
        x = self.conv1(x)

        # The EM Attention
        b, c, h, w = x.size()
        x = x.view(b, c, h*w)               # b * c * n
        mu = self.mu.repeat(b, 1, 1)        # b * c * k    # k 个 基
        with torch.no_grad():
            for i in range(self.stage_num):  # 迭代T次
                x_t = x.permute(0, 2, 1)    # b * n * c
                z = torch.bmm(x_t, mu)      # b * n * k
                z = F.softmax(z, dim=2)     # b * n * k
                z_ = z / (1e-6 + z.sum(dim=1, keepdim=True)) # 这一步对应论文 sec4.2,reweight X的公式
                mu = torch.bmm(x, z_)       # b * c * k
                mu = self._l2norm(mu, dim=1)  # 为了让基的学习更稳定,并且不改变基的方向,保持基的正交性。(正交是冗余最低的形式)

        z_t = z.permute(0, 2, 1)            # b * k * n
        x = mu.matmul(z_t)                  # b * c * n
        x = x.view(b, c, h, w)              # b * c * h * w
        x = F.relu(x, inplace=True)

        # 跳跃链接
        x = self.conv2(x)
        x = x + idn
        x = F.relu(x, inplace=True)

        return x, mu

整个模型的forward结构如下

 def forward(self, img, lbl=None, size=None):
        x = self.extractor(img)    # backbone
        x = self.fc0(x)            # 降维到512个通道
        x, mu = self.emau(x)       # 经过EMA模块
        x = self.fc1(x)            # seg Head
        x = self.fc2(x)

        if size is None:
            size = img.size()[-2:]
        pred = F.interpolate(x, size=size, mode='bilinear', align_corners=True) # 向原图大小插值。这里不能用label向特征大小差值,因为label在原图空间填充了ignore label,如果对label下采样,会破坏ignore label的值。
        if self.training and lbl is not None:
            loss = self.crit(pred, lbl)
            return loss, mu
        else:
            return pred

还有一个地方值得注意,在EMA模块里,

mu = torch.Tensor(1, c, k) # 512 64
mu.normal_(0, math.sqrt(2. / k)) # Init with Kaiming Norm.
mu = self._l2norm(mu, dim=1)
self.register_buffer(‘mu’, mu)

μ \mu μ的初始化这样的。 μ \mu μ不是一个Parameter,而是一个buffer。对应原文,基的训练方式,究竟是通过反向传播训练还是moving average。
EMANet:Expectation-Maximization Attention Networks for Semantic Segmentation论文解读和代码解读_第3张图片

你可能感兴趣的:(语义分割,深度学习,神经网络,ICCV,深度学习)