DBNet++ (TPAMI): Principles and Code Walkthrough

Paper: Real-Time Scene Text Detection with Differentiable Binarization and Adaptive Scale Fusion

Code 1: https://github.com/MhLiao/DB

Code 2: https://github.com/open-mmlab/mmocr

Innovations of This Paper

This paper is an improvement over DBNet; for an introduction to DBNet, see the earlier post "Scene Text Detection Algorithm: Differentiable Binarization DBNet Principles and Code Walkthrough". The paper proposes a new Adaptive Scale Fusion (ASF) module that adaptively fuses multi-scale features; applying ASF to the segmentation network significantly strengthens its ability to detect text instances of different scales.

Method

The complete architecture of DBNet++ is shown in the figure below.

[Figure 1: Overall architecture of DBNet++]

An ASF module is inserted between the multi-level outputs of the FPN and the final prediction feature map.

The complete structure of ASF is shown in the figure below.

[Figure 2: Structure of the Adaptive Scale Fusion (ASF) module]

The FPN output is \(X\in \mathcal{R}^{N\times C\times H\times W}=\left \{ X_{i} \right \}_{i=0}^{N-1} \), where \(N=4\) denotes the four FPN output features at different scales, interpolated to a common spatial size. First, \(X\) is concatenated along the channel dimension and passed through a \(3\times 3\) convolutional layer to obtain the intermediate feature \(S\in \mathcal{R}^{C\times H\times W}\). Then \(S\) passes through a spatial attention module to produce the attention weights \(A\in \mathcal{R}^{N\times H\times W}\). Finally, \(A\) is split evenly into \(N\) parts along the channel dimension, and each part is multiplied with the corresponding feature to obtain the final fused feature \(F\in \mathcal{R}^{N\times C\times H\times W}\).

The complete scale attention process is defined as follows:

[Figure 3: Equations defining the scale attention process]
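The defining equations live in the figure above; restated in symbols (notation mine, consistent with the description of \(X\), \(S\), \(A\), and \(F\) above):

\[ S=\mathrm{conv}_{3\times 3}\!\left(\mathrm{concat}\left(\left \{ X_{i} \right \}_{i=0}^{N-1}\right)\right),\qquad A=\mathrm{SpatialAttention}(S),\qquad F=\mathrm{concat}\left(\left \{ A_{i}\otimes X_{i} \right \}_{i=0}^{N-1}\right) \]

where \(A_{i}\in \mathcal{R}^{1\times H\times W}\) is the \(i\)-th of the \(N\) maps split from \(A\), and \(\otimes\) denotes element-wise multiplication broadcast across the \(C\) channels of \(X_{i}\).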

Code Walkthrough

Here we take mmocr's implementation as the example. Note that the ASF proposed in the paper uses a spatial attention module, but in the official implementation, https://github.com/MhLiao/DB/blob/master/decoders/feature_attention.py, the author provides three different attention mechanisms: besides the spatial attention described in the paper, there are also channel attention and a combined spatial-channel attention. MMOCR only ports the spatial-channel attention, namely ScaleChannelSpatialAttention, shown below:

# Imports needed by this class (as laid out in mmocr 1.x; adjust for your version):
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, Sequential


class ScaleChannelSpatialAttention(BaseModule):
    """Spatial Attention module in Real-Time Scene Text Detection with
    Differentiable Binarization and Adaptive Scale Fusion.

    This was partially adapted from https://github.com/MhLiao/DB

    Args:
        in_channels (int): Number of input channels.
        c_wise_channels (int): Number of channel-wise attention channels.
        out_channels (int): Number of output channels.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(
        self,
        in_channels: int,  # 256
        c_wise_channels: int,  # 64
        out_channels: int,  # 4
        init_cfg: Optional[Union[Dict, List[Dict]]] = [
            dict(type='Kaiming', layer='Conv', bias=0)
        ]
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Channel Wise
        self.channel_wise = Sequential(
            ConvModule(
                in_channels,
                c_wise_channels,
                1,
                bias=False,
                conv_cfg=None,
                norm_cfg=None,
                act_cfg=dict(type='ReLU'),
                inplace=False),
            ConvModule(
                c_wise_channels,
                in_channels,
                1,
                bias=False,
                conv_cfg=None,
                norm_cfg=None,
                act_cfg=dict(type='Sigmoid'),
                inplace=False))
        # Spatial Wise
        self.spatial_wise = Sequential(
            ConvModule(
                1,
                1,
                3,
                padding=1,
                bias=False,
                conv_cfg=None,
                norm_cfg=None,
                act_cfg=dict(type='ReLU'),
                inplace=False),
            ConvModule(
                1,
                1,
                1,
                bias=False,
                conv_cfg=None,
                norm_cfg=None,
                act_cfg=dict(type='Sigmoid'),
                inplace=False))
        # Attention Wise
        self.attention_wise = ConvModule(
            in_channels,
            out_channels,
            1,
            bias=False,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=dict(type='Sigmoid'),
            inplace=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            inputs (Tensor): A concat FPN feature tensor that has the shape of
                :math:`(N, C, H, W)`.

        Returns:
            Tensor: An attention map of shape :math:`(N, C_{out}, H, W)`
            where :math:`C_{out}` is ``out_channels``.
        """
        # inputs: (4,256,160,160)
        out = self.avg_pool(inputs)  # (4,256,1,1)
        out = self.channel_wise(out)  # (4,256,1,1)
        out = out + inputs  # (4,256,160,160)
        inputs = torch.mean(out, dim=1, keepdim=True)  # (4,1,160,160)
        out = self.spatial_wise(inputs) + out  # (4,1,160,160)+(4,256,160,160)->(4,256,160,160)
        out = self.attention_wise(out)  # (4,4,160,160)

        return out
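As a quick shape check, assuming mmocr 1.x is installed (the import path below is where the class lives in mmocr's source tree; adjust it for your version):

import torch
from mmocr.models.textdet.necks.fpn_cat import ScaleChannelSpatialAttention

asf_attn = ScaleChannelSpatialAttention(
    in_channels=256, c_wise_channels=64, out_channels=4)
x = torch.randn(4, 256, 160, 160)  # concatenated FPN features after the 3x3 conv
print(asf_attn(x).shape)           # torch.Size([4, 4, 160, 160])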

Here batch_size=4 and input_size=(640, 640). After upsampling, the four FPN output levels share the same spatial size, i.e. the list [(4,64,160,160),(4,64,160,160),(4,64,160,160),(4,64,160,160)]. They are concatenated along the channel dimension into a tensor of shape (4,256,160,160), which then passes through a 3x3 convolutional layer (the shape stays (4,256,160,160)) to produce the input of the ASF module.
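A minimal sketch of this input preparation (the stride layout and nearest-neighbor interpolation are my assumptions; in mmocr this happens inside the FPNC neck):

import torch
import torch.nn as nn
import torch.nn.functional as F

# four FPN levels for a 640x640 input, each with 64 channels (strides 4/8/16/32 assumed)
fpn_outs = [torch.randn(4, 64, 640 // s, 640 // s) for s in (4, 8, 16, 32)]
# upsample every level to the largest resolution, 160x160
outs = [F.interpolate(f, size=(160, 160), mode='nearest') for f in fpn_outs]
x = torch.cat(outs, dim=1)                                # (4, 256, 160, 160)
asf_conv = nn.Conv2d(256, 256, kernel_size=3, padding=1)  # shape-preserving 3x3 conv
asf_input = asf_conv(x)                                   # (4, 256, 160, 160): input to ASF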

The input first passes through global average pooling, giving a (4,256,1,1) output. The channel attention module self.channel_wise is a two-layer convolution, conv1x1-64-ReLU-conv1x1-256-Sigmoid, which outputs channel attention weights of the same size; these are added to the original input. Next, the mean is taken along the channel dimension and fed to the spatial attention module self.spatial_wise, also a two-layer convolution, conv3x3-1-ReLU-conv1x1-1-Sigmoid, whose spatial attention weights are again added to the input. Finally, conv1x1-4-Sigmoid (self.attention_wise) yields the ASF module's output of shape (4,4,160,160).
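Restating this forward pass as equations (notation mine, read directly off the code above), with ASF input \(S\):

\[ M_{c}=\sigma\!\left(W_{2}\,\mathrm{ReLU}\!\left(W_{1}\,\mathrm{GAP}(S)\right)\right),\qquad S_{1}=M_{c}+S \]

\[ M_{s}=\sigma\!\left(V_{2}\,\mathrm{ReLU}\!\left(V_{1}\,\mathrm{mean}_{c}(S_{1})\right)\right),\qquad S_{2}=M_{s}+S_{1} \]

\[ A=\sigma\!\left(U\,S_{2}\right)\in \mathcal{R}^{N\times H\times W} \]

where \(\mathrm{GAP}\) is global average pooling, \(\mathrm{mean}_{c}\) the channel-wise mean, \(\sigma\) the sigmoid, \(W_{1},W_{2}\) the 1x1 convolutions in self.channel_wise, \(V_{1},V_{2}\) the 3x3 and 1x1 convolutions in self.spatial_wise, and \(U\) the 1x1 convolution in self.attention_wise; \(M_{c}\in\mathcal{R}^{C\times 1\times 1}\) and \(M_{s}\in\mathcal{R}^{1\times H\times W}\) are broadcast in the additions.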

The four attention maps output by the ASF module are then multiplied, as weights, with the corresponding four original FPN outputs, and the weighted features are concatenated along the channel dimension to give the final output:

enhanced_feature = []
for i, out in enumerate(outs):  # outs: the 4 upsampled FPN features, each (4,64,160,160)
    # attention[:, i:i + 1] is the i-th attention map (4,1,160,160), broadcast over the 64 channels
    enhanced_feature.append(attention[:, i:i + 1] * outs[i])
out = torch.cat(enhanced_feature, dim=1)  # (4,256,160,160)
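A self-contained check of the broadcasting in this loop (all tensors fabricated for illustration):

import torch

attention = torch.rand(4, 4, 160, 160)                   # stand-in for the ASF output: one map per scale
outs = [torch.randn(4, 64, 160, 160) for _ in range(4)]  # stand-ins for the upsampled FPN levels

enhanced_feature = [attention[:, i:i + 1] * outs[i] for i in range(4)]
out = torch.cat(enhanced_feature, dim=1)
assert out.shape == (4, 256, 160, 160)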
