mmdetection(5):GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond 理解

这篇文章让我很是疑惑,已经跑离作者的思路了,
mmdetection(5):GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond 理解_第1张图片

欧克,那我们先来看一下把, a模型是NL block ,即non-local ,我们前文讲过,意思就是通过操作 i 和 j 点的特征,来判断其是否应该相连,然后在乘加上 j 点的特征,最后得到 i 点卷积的结果。那wh*wh就表示相连关系啊。

那我们来梳理一下啊,NL我们到底要个什么东西,很明显我们的要的是卷积的结果啊,而卷积的是 各个 j 点啊,类似普通卷积的上下左右9个点。该模块是代替的卷积核的作用。

GCnet这篇文章讲了什么呢,在NL中我们用的是hw和hw,表示的是我们特征图上的每个点都有各自不同的关注位置,
而本文经过实验证明每个点的关注位置都是一样的,那我们可以简化一下,只用一个图hw来表示注意力机制,那这样就没发用hw*hw 的形式了,只能仅仅根据 j 点的特征值来判断我们是否需要 j 点的信息。

讲到这里大家可能都能理解了,但是有个问题啊,这样做相当于每个点都是取的同几个位置啊,模仿NL的做法,我们应该首先计算attention图,然后对其卷积,得到某一点的值,然后不啦不啦不啦,我编不下去了,最后得到的是每个点的值都一样,这不是扯吗?

mmdetection(5):GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond 理解_第2张图片
但是该文章最后整出来了个啥,计算关注点位置,然后看哪一个channel的关注点信息更多,给其增大权重,senet?随便把,大佬的世界搞不懂。

不过代码还是要看的
简单的很,就是首先计算特征,然后输出一通道的权重图,即context_mask,然后softmax一下,乘上输出。得到context。

def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context


self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)

你可能感兴趣的:(检测算法-深度学习)