mmaction2 行为识别模型相关源码

文章目录

    • 0. 前言
    • 1. 模型创建过程详解
    • 2. 基本模型详解
      • 2.1. `BaseRecognizer` 源码详解
      • 2.2. `Recognizer2D` 源码详解
      • 2.3. `Recognizer3D` 源码详解
    • 3. TSN与TSM的实现
      • 3.1. TSN 的实现
      • 3.2. TSM 的实现
    • 4. I3D/R(2+1)D/Slow/SlowFast 的实现
      • 4.1. I3D的实现
      • 4.2. R(2+1)D的实现
      • 4.3. SlowFast的实现
      • 4.4. Slow 的实现


0. 前言

  • mmaction2 目前支持行为识别模型以及时序行为检测模型。
    • 行为识别模型包括TSN/TSM/I3D/R(2+1)D/Slow/SlowFast。
      • 本文将详细描述。
    • 时序行为检测模型包括BMN/BSN。
      • 本文不涉及。
  • 从宏观角度描述下 mmaction2 模型相关源码
    • 模型相关源码结构:
      • 对于每一类模型(行为识别模型、时序行为检测模型)都有若干基本模型(如行为识别模型中的2D模型与3D模型)。
      • 基本模型定义了模型的基本组成部分(称为组件)以及数据的基本流向。
        • 3D行为识别模型 Recognizer3D 的基本组件就包括了 backbone 与 head。
        • 3D行为识别模型 Recognizer3D 的数据基本流向就是,原始图片通过backbone提取特征,提取到的结果作为head的输入,最终得到分类结果。
        • 换句话说,在定义了基本模型后,我们要做的工作就是定义各个组件的具体类型。
      • 所有具体模型(如I3D/SlowFast)就是在基本模型的基础上,指定了所有组件的具体类型以及相关参数。
        • 如I3D模型,就是指定了backbone为 ResNet3d,head为I3DHead
      • 源码剩下的阅读任务就是研究各个组件的具体实现过程。
    • train/val/test时构建模型的过程:
      • 简单说就是搭积木。
      • 仔细点说就是在配置文件中定义好基本模型与组件的类型与参数,然后在入口函数中根据配置文件创建具体模型
      • PS:源码看过了就没花头,mmdetection/detectron2/slowfast/mmaction都是一个套路。
  • 本文剩余部分分为两类:
    • 第一章:基本模型以及模型创建,分析基本模型的细节以及模型创建的细节
    • 第二章:具体模型实现,分析I3D/R(2+1)D/Slow/SlowFast/TSN/TSM的具体实现细节。

1. 模型创建过程详解

  • 根据前面的介绍,模型创建总体过程就是:根据配置文件中的内容搭积木,通过基本模型+具体组件,最终组成具体模型
  • 使用了Registry机制
    • 看过 mmdetection/detectron2/slowfast/mmaction 源码的应该都非常熟悉这种机制了。
    • 对于每类积木(如行为识别基本类型、backbone、head)都对应一个 Registry 对象。
    • 每个Registry对象主要实现两个功能:
      • 功能一:维护一个字典,key为字符串形式的类型名称,value为类。
        • 例如,backbone的Registry对象,key就是字符串如 "ResNet3d",value就是同名class对象(可通过dict[key](...)创建一个ResNet3d类的对象)
      • 功能二:可通过注解新增 key-value 对。
    • 对于所有相关类,都通过注解进行注册。
    • 对于行为识别模型,相关的Registry对象包括 BACKBONES, HEADS, RECOGNIZERS
  • 入口函数:mmaction.models.builder.py 中的 build_model
  • 具体创建过程:
    • 第一步:根据配置文件中的 cfg.model.type 字符串,在 RECOGNIZERS 中选择对应的 基本类型
    • 第二步:将 cfg.model 中除 type 外的参数以及 cfg.train_cfg/cfg.test_cfg 中的参数作为 基本类型 的初始化参数传入。
    • 第三步:在基本类型的初始化过程中,就会通过传入的参数构建具体组件
      • 构建具体组件的过程其实也就是通过相关参数选择对应Registry对象,然后构建,没什么好多说的了。

2. 基本模型详解

  • 行为识别模型的基本类型包括两类 Recognizer2DRecognizer3D,都继承了 BaseRecognizer

2.1. BaseRecognizer 源码详解

  • 所有行为识别模型都应该继承该类。
  • 在继承该类时,应重写 forward_trainforward_test 方法,分别表示训练/预测过程。
    • 两者的输入都是 imgs,shape应该是 B, T, C, H, WB, C, T, H, W 等,看数据预处理怎么定义的。
    • 前者的输出是losses,后者的输出是分类结果。
  • 定义了 train_stepval_step,前者定义了获取losses的过程,后者定义了获取预测结果的过程。
    • 封装了 forward 方法。
    • 虽然输入参数中包括了 optimizer,但方法中只获取了损失函数,并没有进一步进行梯度下降。
  • 损失函数结果预处理,即 _parse_losses 方法
    • 将结果构造为一个字典。
    • 新增总损失函数 loss
    • 处理分布式训练时的问题,集合所有loss的值。
  • 平均所有clip的结果,即 average_clip 方法
    • 看源码,相关参数在 test_cfg 中,是测试时专用的吗?
    • 可能是取了多个crop然后平均一下结果?
    • 不确定,等以后用到相关功能的时候再说吧。

2.2. Recognizer2D 源码详解

  • TSN/TSM 继承了该类。
  • 看了下数据相关源码,在TSN/TSM中使用的输入数据format都是 NCHW.
  • 在train/test中都对数据进行了reshape
    • 主要目标就是将 BATCH_SIZE, N, C, H, W 的数据转换为 BATCH_SIZExN, C, H, W
    • 毕竟,对于2D网络来说,对每张图片都要用2D CNN来提取特征。
    • 换句话说,2D网络需要输入数据是4维的。
  • 另外,不同于3D网络,在train/test过程中都获取了 num_segs 参数作为 head 的输入。
  • 有一个用于计算FLOPs的forward_dummy函数,后续看FLOPs相关源码时再说。

2.3. Recognizer3D 源码详解

  • I3D/R(2+1)D/Slow/SlowFast 继承了该类。
  • 看了下数据相关源码,输入数据主要用的都是 NCTHW 形式。
  • 在train/test过程中都对数据进行了reshape
    • 主要目标是将 BATCH_SIZE, M', C, T, H, W 形式的数据转换为 BATCH_SIZExM', C, T, H, W
    • 对于3D网络来说,需要的数据输入形式就是5维的。
  • 与2D网络不通,这里并不需要 num_segs
    • 感觉这个参数就类似于 T 维,3D网络中直接处理了。
  • 有一个用于计算FLOPs的forward_dummy函数,后续看FLOPs相关源码时再说。

3. TSN与TSM的实现

3.1. TSN 的实现

  • 使用了 Recognizer2D 作为基础类型,backbone选择了 ResNet,head 选择了 TSNHead
  • backbone没啥要说了,普通ResNet。
  • 对于TSNHead稍微多说几句:
    • 其实了解TSN的应该知道要做啥。
    • 输入的特征图尺寸其实是 N * num_segs,即包括了 batch size 以及一个clip中的T帧图片。
    • TSN 做的工作就是对每个clip的 num_segs 帧结果取平均,得到最终结果。
    • 做的工作就是 N * num_segs, in_channels, h, w 经过reshape与avg pool得到 N, inchannels 的特征,然后通过一个全连接层进行分类得到最终结果。如果有必要的话,再加上一个dropout。

3.2. TSM 的实现

  • 使用了 Recognizer2D 作为基础类型,backbone选择了 ResNetTSM,head 选择了 TSMHead
  • 对于Backbone,与普通ResNet的区别就在于,对所有block的的conv1添加了 shift 操作。
  • Shift操作的具体实现
    • 在TSM作者提供的源码中,Shift操作主要通过slice赋值操作实现,如 out[:, :-1, :fold] = x[:, 1:, :fold]。这些操作在onnx/TVM转换的时候存在问题。
    • mmaction2的作者使用了分别获取每一块,然后concat得到最终结果(而不是slice赋值),这样可能onnx等转换的时候方便一点。
      • 这个原因是我猜的,暂时还没进行onnx转换啥的。
      • 之前转换原版TSM源码时候,经常出现的错误是 fold = c // shift_div 这里出错,mmaction2中还是保留有这个,不知道转换起来有没有什么问题。
    • 另外,mmaction的实现中没有使用 torch.zeros(), torch.zeros_like() 操作,好像是caffe inference 的问题,不过我没碰到过。
  • TSMHead 的实现细节
    • TSN是直接先将结果转换为 N, in_channels 再进行fc。
    • TSM中则是在 N * num_segs, in_channels 中就计算fc,得到结果再进行avg。
  • temporal pool 功能
    • 在看原版代码的时候,temporal pool 执行的操作就是在ResNet的layer2钱增加了一个 T 纬度上的 3,1,1/2,1,1 的max pooling操作。
    • 即,backbone 中 stage2/3/4 的 num_segments 数值减半。
    • TSMHead 也需要注意 num_classes 的取值。
    • mmaction2 的实现好像有问题,明天试一下。

4. I3D/R(2+1)D/Slow/SlowFast 的实现

4.1. I3D的实现

  • 使用了 Recognizer3D 作为基础类型,backbone选择了 ResNet3d,head 选择了 I3DHead
  • backbone 的整体结构与 ResNet 完全相同
    • 包括stage数量,block结构与数量,conv/bn/relu的数量。
  • backbone 中 ResNetResNet3d 的不同之处在于:
    • 所有2D BN和2D CNN都转换为3D BN和3D CNN。
    • CNN多了一维temporal的,那就多了对应的kenrel size与stride。
  • backbone 的具体变化,即3D卷积的kernel size与stride
    • STEM中的变化
      • 卷积从原先的7x7/stride(2,2)改为5x7x7/stride(2,2,2)
      • max pooling从原先的3x3/stride(2,2)改为1x3x3/stride(2,2,2)
    • stage总体变化:
      • 原本四个stage的stride(都是空间)是(1,2,2,2),现在分为时间、空间两个维度,时间上stride为(1,1,1,1),空间上维度与之前相同,为(1,2,2,2)。
      • inflate相关
        • 本质就是 temporal 维度上kernel size的变化,stride都是1。
        • 所谓的 inflate 翻译应该就是膨胀的意思,好像是通过2D卷积实现类似3D卷积的功能(但看源码好像不是这个意思,具体看下面的实现)。
        • 换句话说,在inflate模式下,一次3x3x3的卷积需要转换为3x1x1+1x3x3两个卷积实现。
        • 参数包括inflate_freq与inflate_stype,前者是每个block都有对应的参数(判断当前block是否需要进行inflate操作),后者表示inflate类型。
    • inflate的具体实现:
      • 对于BasicBlock有两种模式:inflate模式与非inflate模式
        • inflate模式下第两个卷积都使用3x3x3的卷积核。
        • 非inflate模式下,两个卷积都使用1x3x3的卷积核。
      • 对于Bottleneck有三种模式:非inflate模式,inflate 3x1x1模式,inflate 3x3x3 模式
        • 非inflate模式:1x1x1+1x3x3+1x1x1
        • inflate 3x1x1模式(最常用):3x1x1+1x3x3+1x1x1
        • inflate 3x3x3模式:1x1x1+3x3x3+1x1x1
  • I3DHead 实现的功能非常简单
    • 先将输入的 N, in_channels, T, H, W 通过 avg pool 转换为 N, in_channels
    • 然后经过dropout+fc,得到分类结果。

4.2. R(2+1)D的实现

  • 使用了 Recognizer3D 作为基础类型,backbone选择了 ResNet2Plus1d,head 选择了 I3DHead
  • 趁这个位置,说说 mmaction2 中模型构建相关代码中不太一样的地方。
    • 比如要定义一个普通2D卷积操作,mmaction2中不使用 torch.nn.Conv2d 这样的默认API,而是会使用 mmcv.cnn.ConvModule
    • mmcv.cnn.ConvModule 会根据输入的 conv_cfgnorm_cfg 构建对应的卷积操作。
    • 能够创建的操作包括 Conv1D/Conv2D/Conv3D
  • 为了实现R(2+1)D,mmaction2 定义了一个新的 conv_cfg 参数,即 Conv2plus1d.
    • 该参数的具体实现位于 mmaction.models.common.conv2plus1d.py 中的 Conv2plus1d
  • ResNet3d(I3D)与ResNet2Plus1d(R(2+1)D)之间的异同
    • 将I3D中的所有 Conv3d 转换为 Conv2plus1d + 忽略 I3D 中的 pool2,这样的结果就是 R(2+1)D。

4.3. SlowFast的实现

  • 使用了 Recognizer3D 作为基础类型,backbone选择了 ResNet3dSlowFast,head 选择了 SlowFastHead
  • SlowFast 的配置与其他的略有不同
    • 需要配置两个 ResNet3dSlowFast 对象,分别表示 Slow 分支和 Fast 分支。
    • ResNet3dSlowFast 对象继承自 ResNet3d,后续会单独介绍。
    • Slow分支还包含了lateral分支,即特征融合分支。特征融合相关后面会单独介绍。
  • ResNet3dSlowFast 简介
    • 该对象继承了 ResNet3d
    • 主要区别在于增加了 lateral 相关内容,即SlowFast融合相关源码。
    • Slow分支
      • 采样率由输入数据与self.resample_rate决定。
      • 包含 lateral 相关内容。
      • STEM的时间维度相关的kernel size与conv/pool stride都是1。
      • 4 stages中spatial stride都分别是(1,2,2,2),每个stage中不同block对应的inflate参数都相同,四个stage的 inflate freq为(0,0,1,1),都是3x1x1形式的。
      • 总体channels数量也与普通I3D相同。
    • Fast分支
      • 采样率由输入数据与self.resample_rate//self.speed_ratio决定,后者一般为1。
      • 不包含 lateral 相关内容。
      • STEM的时间维度相关的conv kernel size为5,pool/pool stride都是1。
      • 4 stages中spatial stride都分别是(1,2,2,2),每个stage中不同block对应的inflate参数都相同,四个stage的 inflate freq为(1,1,1,1),都是3x1x1形式的。
  • lateral 详解
    • 作用:融合slow与fast分支。
    • 前提:slow与fast除了channel数量外,其他结构基本都是相同的
    • 基本做法就是将fast分支中某个位置的特征经过3D卷积转换,然后与同一层的slow分支进行concat操作,concat后结果作为slow分支的输出。
    • 3D卷积的实现细节:kernel size为 (5,1,1), stride为 (alpha,1,1),padding为(2,0,0),channel数量x2。
  • SlowFastHead 的实现
    • 其实就是先分别将 slow 与 fast 分支的结果在T, H, W纬度进行 avg pool。
    • 对avg pool结果concat后执行 dropout+fc。

4.4. Slow 的实现

  • 使用了 Recognizer3D 作为基础类型,backbone选择了 ResNet3dSlowOnly,head 选择了 I3DHead

  • 其实我没看出来与I3D有什么区别……没一行一行对着看,感觉没啥区别……

  • 好像是参数 with_pool2 不起作用了?

  • 没证没啥大的不同

你可能感兴趣的:(PyTorch)