ConvNeXt:超越 Transformer?总结涨点技巧与理解代码(附注释)

文章目录

    • 1. 前言
    • 2. 论文学习
      • 2.1. 发展路线
      • 2.2. 训练策略改进
      • 2.3. 宏观设计
        • 2.3.1. stage compute ratio
        • 2.3.2. stem to “Patchify”
      • 2.4. ResNeXt-ify
      • 2.5. 反转瓶颈
      • 2.6. 卷积核大小
      • 2.7. 微小改进
        • 2.7.1. GELU 替换 ReLU
        • 2.7.2. 减少激活层
        • 2.7.3. 减少归一化层
        • 2.7.4. LayerNorm 替换 BatchNorm
        • 2.7.5. 单独的下采样层
    • 3. 代码理解
    • 4. 涨点技巧
    • 5. 结语

1. 前言

近年来,Vision Transformer 在图像分类问题中表现出极大优势。但在目标检测,语义分割和图像恢复等问题中存在诸多不足,因此在该类问题中以 ConvNet 作为先验, Transformer 作为主干网,从而获得优势。

研究者发现这种混合的 ConvNet + Transformer 的网络结构的优势主要归因于 Transformer,而不是 ConvNet 固有的归纳偏置。因此,研究者逐渐改进 ResNet 的结构和训练,旨在探索导致 ConvNet 和 Transformer 性能差异的关键因素。

2. 论文学习

ConvNeXt:超越 Transformer?总结涨点技巧与理解代码(附注释)_第1张图片
论文:https://arxiv.org/pdf/2201.03545.pdf
代码:https://github.com/facebookresearch/ConvNeXt

2.1. 发展路线

该研究梳理了从 ResNet 到类似于 Transformer 的卷积神经网络的发展路线,并根据 FLOPs 考虑两种模型大小,一种是 ResNet-50 / Swin-T ,其 FLOPs 约为 4.5×109,另一种是 ResNet-200 / Swin-B ,其 FLOPs 约为 15.0×109
ConvNeXt:超越 Transformer?总结涨点技巧与理解代码(附注释)_第2张图片
在论文中作者从 ResNet-50 网络结构出发,首先使用训练 Vision Transformer 的类似训练方法进行训练,结果发现相比原始 ResNet-50 的训练方法获得很大的提升。之后作者分别从5个发面升级网络结构: 1)宏观设计; 2)ResNeXt; 3)反转瓶颈; 4)卷积核大小; 5)逐层微小改进。所有的训练和测试都是在 ImageNet-1K 数据集上进行的。

2.2. 训练策略改进

论文对训练策略优化的细节改进相对较多,完整的请自行阅读论文。其中主要改进的地方是

1) 优化器: SGD --> AdamW
2) 权重衰减: 1e-4 --> 5e-2
3) 初始学习率:0.1 --> 4e-3
4) 学习率调整策略:StepLR(step=30, gamma=0.1) --> Cosine
5) batch size: 32 --> 512
6) epoch: 90 --> 300

同时参与对比的除标准 ResNet 外还包括 2021年 timm 和 torchvision 团队对训练策略的研究成果。最后准确率从原始 76.1% 上升至 78.8%。

2.3. 宏观设计

2.3.1. stage compute ratio

将 layer0 到 layer3 的网络块的数量比例由标准的 [3, 4, 6, 3] 改为 Swin-T 中的 [3, 3, 9, 3]。同时在大模型中也与 Swin-T 的 [1, 1, 9, 1] 保持一致。
提升:78.8 --> 79.4

2.3.2. stem to “Patchify”

类似 Swin-T·,论文采用步长为 4 的 4×4 卷积核,从而令窗口之间不想交,每次只处理一个 patch 的特征,同时在卷积层后加入 LayerNorm。
提升:79.4 --> 79.5

2.4. ResNeXt-ify

在这里论文主要基于分组卷积的思想,使用 depth-wise 卷积核替换 bottleneck 中的 3×3 卷积核,令分组数等于通道数。主要思想是每个 depth-wise 卷积核处理一个通道的信息,类似于 Transformer 中的 self-attention。此外,论文将原始的特征图通道数从 64 提升到 96。
提升:79.5 --> 80.5

2.5. 反转瓶颈

标准的 ResNet 中采用的 bottleneck 结构是 [大维度, 小维度, 大维度] 的形式,主要考虑是降低网络的计算量。但后来在 MobileNetV2 中研究者认为这种压缩维度的变换会带来信息损失,因此提出了反瓶颈结构,即 [小维度, 大维度, 小维度] 的形式,类似于 Transformer 中 MLP 的结构。
提升:80.5 --> 80.6

2.6. 卷积核大小

在 Swin-T 中采用了大小为 7×7 的卷积核,为了进行比较在这里作者将 depth-wise 卷积核的大小改为 7×7,同时为了避免反瓶颈结构参数量明显增大的问题将 depth-wise 卷积层放在了 bottleneck 的开始,如图所示。
提升:80.6 --> 80.6
ConvNeXt:超越 Transformer?总结涨点技巧与理解代码(附注释)_第3张图片

2.7. 微小改进

2.7.1. GELU 替换 ReLU

提升:80.6 --> 80.6

2.7.2. 减少激活层

类似于 Transformer 中只存在一次激活层,论文发现在 bottleneck 的两个 1×1 卷积核之间使用激活层,其他位置不适用。同时该实验说明每一级卷积层后并不都是需要激活层的,频繁的非线性映射反而不利于网络特征的学习。
提升:80.6 --> 81.3

2.7.3. 减少归一化层

与减少激活层同理,论文在 bottleneck 中只保留了 1×1 卷积前的一层 BatchNorm。
提升:81.3 --> 81.4

2.7.4. LayerNorm 替换 BatchNorm

一方面 Transformer 中使用 LayerNorm,另一方面一些研究发现 BatchNorm 会对网络性能带来负面影响,论文将 BatchNorm 替换为 LayerNorm。
提升:81.4 --> 81.5

2.7.5. 单独的下采样层

标准的 ResNet 的下采样操作通常由步长为 2 的 3×3 卷积来实现,如果存在跳跃连接则利用步长为 2 的 1×1 卷积来传递低级特征。类似于 Swin-T 中单独的下采样层,论文利用 步长为 2 的 2×2 卷积进行模拟。同时在下采样层后加入 LayerNorm 来避免训练不稳定的问题。
提升:81.5 --> 82.0

3. 代码理解

ConvNeXt 的 block 结构如下:
ConvNeXt:超越 Transformer?总结涨点技巧与理解代码(附注释)_第4张图片
ConvNeXt 的主要代码及注释如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath

class Block(nn.Module):
    r""" ConvNeXt Block. 两个等效的实现:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    使用第二种实现,因为作者发现这种实现在 pytorch 中快一些。
    
    参数:
        dim (int): 输入特征的通道数
        drop_path (float): 随机深度丢弃率,默认为 0.0
        layer_scale_init_value (float): 层缩放的初始值,默认为 1e-6
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depth-wise 卷积
        self.norm = LayerNorm(dim, eps=1e-6)  # 归一化
        self.pwconv1 = nn.Linear(dim, 4 * dim) # 全连接层
        self.act = nn.GELU()  # 激活层
        self.pwconv2 = nn.Linear(4 * dim, dim)  # 反瓶颈结构,中间层维度变大
		# 层缩放
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 
                                    requires_grad=True) if layer_scale_init_value > 0 else None
		# 随机深度丢弃
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
		# 调整通道顺序,由于后边要用全连接层
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)  # 激活层
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x  # 缩放
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)  # 残差
        return x

class ConvNeXt(nn.Module):
    r""" ConvNeXt

    参数:
        in_chans (int):输入图像的通道数,默认为 3
        num_classes (int): 分类的数量,默认为 1000
        depths (tuple(int)): 每个阶段的 block 个数,默认为 [3, 3, 9, 3]
        dims (int): 每个阶段的特征维度,默认为 [96, 192, 384, 768]
        drop_path_rate (float): 随机深度丢弃率,默认为 0
        layer_scale_init_value (float): 层缩放的初始值,默认为 1e-6
        head_init_scale (float): 分类器权重和偏置的初始化缩放值,默认为 1
    """
    def __init__(self, in_chans=3, num_classes=1000, 
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 
                 layer_scale_init_value=1e-6, head_init_scale=1.,
                 ):
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # 下采样模块
		# patch 化的 stem,步长 4, 卷积核大小 4
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)  # 下采样
        for i in range(3):
			# 步长 2, 卷积核大小 2
            downsample_layer = nn.Sequential(
                    LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),  # 归一化
                    nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),  # 下采样
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList() # 4 个分辨率阶段,每个阶段包括多个残差块
		# 每一层的随机深度丢弃率
        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 
                layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]  # 每一层的缩放初始值
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # 归一化
        self.head = nn.Linear(dims[-1], num_classes)  # 全连接层

        self.apply(self._init_weights)  # 初始化权重
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):  # 初始化权重方法
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)  # 截断正态分布
            nn.init.constant_(m.bias, 0)  # 常数

    def forward_features(self, x):  # 特征前向传播
        for i in range(4):
            x = self.downsample_layers[i](x)  # 下采样层
            x = self.stages[i](x)
        return self.norm(x.mean([-2, -1])) # 全局平均池化, (N, C, H, W) -> (N, C)

    def forward(self, x):  # 总前向传播
        x = self.forward_features(x)
        x = self.head(x)
        return x

class LayerNorm(nn.Module):  # 归一化
    r""" LayerNorm 支持两种数据格式,对应 block 的两种实现: channels_last (default) or channels_first
    channels_last 对应输入结构为 (batch_size, height, width, channels)
	channels_first 对应输入结构为 (batch_size, channels, height, width)
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))  # 权重
        self.bias = nn.Parameter(torch.zeros(normalized_shape))  # 偏置
        self.eps = eps  # 维稳常数
        self.data_format = data_format  # 输入数据格式
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError 
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)  # 自带归一化函数
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)  # 1 维度求均值
            s = (x - u).pow(2).mean(1, keepdim=True)  # 均方差
            x = (x - u) / torch.sqrt(s + self.eps)  # 正态归一
            x = self.weight[:, None, None] * x + self.bias[:, None, None]  # 归一化解结果

4. 涨点技巧

通过对论文的实验结果可以发现,尽管作者尝试了诸多改进,如果考虑再考虑到神经网络训练存在的一些随机性,可以判断所有的改进中真正 work 的不多。因此这里将所有的改进中对图像分类准确率的影响较大的改进点总结如下

1) 网络训练策略
2) 不同阶段 block 的数量
3) patch 化的下采样 stem
4) 分组卷积
5) 减少激活层数量
6) 单独的下采样层

这些改进点是论文中作者在 ResNet-50 上实验结果总结得出,但在 ImagNet-1K 数据集上不能代表在其他数据集上也可以表现良好。同理,在图像分类问题上表现良好不代表在目标检测,语义分割和图像恢复问题上也 work。另外,是否这些技巧在轻量级网络上 work 也尚是未知数。

5. 结语

这篇论文的重点在与将 ResNet-50 通过一步步改进从而在 ImageNet-1K 数据集上获得超越 Transformer 的表现,也从另一方面说明 CNN 的性能还有进一步提升的空间。但目前拥抱 Transformer 已经是大势所趋,是否 CNN 能文艺复兴还需要大佬们继续烧电费研究。

你可能感兴趣的:(pytorch,深度学习,transformer,深度学习,计算机视觉,pytorch,人工智能)