【深度学习:SENet】信道注意力和挤压激励网络(SENet):图像识别的新突破

【深度学习:SENet】信道注意力和挤压激励网络(SENet):图像识别的新突破

    • 为什么有效
    • 如何实现
    • 工作原理
    • 应用案例

【深度学习:SENet】信道注意力和挤压激励网络(SENet):图像识别的新突破_第1张图片

挤压和激励网络(SENets)为卷积神经网络(CNN)引入了一个新的构建模块,该模块能在几乎不增加计算成本的情况下,显著改善通道间的相互依赖性。它们在今年的ImageNet比赛中大放异彩,并帮助将去年的成绩提升了25%。除了巨大的性能提升,SENets还可以轻松集成到现有的架构中。其主要思想是:

为卷积块的每个通道添加参数,让网络能够自适应地调整每个特征图的权重。

这听起来很简单,实际上也确实如此。下面让我们详细了解为什么它如此有效,以及如何用几行简单的代码改进任何模型。

为什么有效

CNN通过其卷积滤波器从图像中提取分层信息。较低层次的网络可以检测简单的上下文特征,如边缘或纹理,而较高层次则能识别人脸、文字或其他复杂的几何形状。它们提取的信息对解决特定任务至关重要。

这些成果是通过融合图像的空间和通道信息实现的。不同的过滤器会首先在每个输入通道中寻找空间特征,然后在所有输出通道中整合这些信息。我在另一篇文章中对此进行了更详细的介绍。

您需要了解的是,在创建输出特征图时,网络对其每个通道的权重是相等的。SENets的目标是改变这一点,通过引入一种内容感知机制,自适应地加权每个通道。最基本的形式是为每个通道添加一个参数,并赋予它一个线性标量值。

然而,作者更进一步。他们首先通过全局平均池化将每个特征图压缩成一个单一数值,全面了解每个通道的信息。这样就产生了一个长度为卷积通道数量的向量。然后,该向量被输入到一个由两层神经网络组成的结构中,输出一个相同长度的向量。这些新的值现在可以作为原始特征图的权重,根据它们的重要性调整每个通道。

如何实现

您可能认为,这听起来并不像我最初承诺的那样简单。但事实上,实现SE块的过程是直接且简洁的。下面是一个示例代码:

def se_block(in_block, ch, ratio=16):
    x = GlobalAveragePooling2D()(in_block)
    x = Dense(ch//ratio, activation='relu')(x)
    x = Dense(ch, activation='sigmoid')(x)
    return multiply()([in_block, x])
  1. 此函数接收一个输入卷积块及其通道数。
  2. 我们使用全局平均池化将每个通道压缩成单个数值。
  3. 通过一个全连接层,加上ReLU激活函数,引入非线性,并降低输出通道的复杂度。
  4. 紧接着是第二个全连接层,使用Sigmoid激活函数,为每个通道提供平滑的门控功能。
  5. 最后,我们根据侧边网络的输出对卷积块的每个特征图进行加权。

这五个步骤几乎不会增加额外的计算成本(少于1%),且可以轻松添加到任何模型中。

Vanilla ResNet模块与集成了SE块的ResNet模块对比图

【深度学习:SENet】信道注意力和挤压激励网络(SENet):图像识别的新突破_第2张图片

Vanilla ResNet 模块与建议的 SE-ResNet 模块

作者指出,通过向ResNet-50添加SE块,可以达到与ResNet-101几乎相同的精度。这对于计算成本仅为一半的模型来说是令人印象深刻的。本文还探讨了其他架构,如Inception、Inception-ResNet和ResNeXt,这些架构通过集成SE块在ImageNet上展现了更低的错误率。

【深度学习:SENet】信道注意力和挤压激励网络(SENet):图像识别的新突破_第3张图片

SENets 如何改进现有架构

SENets令人惊讶的是其简单性和有效性。能够几乎免费地将这种方法应用于任何模型,应该会激励您重新审视并训练您之前构建的所有模型。

工作原理

SENet主要通过以下三个步骤实现其目标:

  1. 挤压(Squeeze): 通过全局平均池化层将空间维度的信息压缩为一个通道描述符,这有助于凸显每个通道的全局重要性。

  2. 激励(Excitation): 使用全连接层学习通道间的依赖关系,并应用sigmoid激活函数以获取每个通道的重要性权重。

  3. 重标定(Re-calibration): 将计算得出的权重与原始特征图相乘,实现对特征图的动态调整。

应用案例

SENet已在图像识别、目标检测和图像分割等多种任务中得到广泛应用,并显著提高了性能。例如,在图像识别任务中,通过集成SE模块,模型能更准确地识别图像中的关键对象。

你可能感兴趣的:(深度学习知识库,深度学习,网络,人工智能)