【MMAction2】学习笔记(1)——训练数据读取过程pipeline

【MMAction2】学习笔记(1)——训练数据读取过程pipeline_第1张图片


目录

  • 0. MMAction2 介绍
  • 1. Over View of Pipeline
  • 2. Data Loading
  • 3. Pre-processing & Formatting


0. MMAction2 介绍

  MMAction2 是一款基于 PyTorch 框架的视频理解开源工具箱,是由商汤科技和港中文大学联合提出的 OpenMMLab 项目的成员之一。
  视频理解是计算机视觉中重要的研究方向,近年来逐渐成为业界和学术界的研究热点,同时也被广泛应用在智能监控/视频推荐等领域中。OpenMMLab 项目开源了 MMAction2,这是一套基于 PyTorch 实现的视频理解工具箱和 benchmark,目前包含了视频理解领域常见的任务,比如动作识别时序动作检测时空动作检测等。

  • 模块化设计: MMAction2 将统一的视频理解框架解耦成不同的模块组件,通过组合不同的模块组件,用户可以便捷地构建自定义的视频理解模型。
  • 支持多种任务和数据集: MMAction2 支持多种视频理解任务,包括动作识别,时序动作检测,时空动作检测以及基于人体姿态的动作识别,总共支持 27 种算法和 20 种数据集
  • 详尽的单元测试和文档: MMAction2 提供了详尽的说明文档,API 接口说明,全面的单元测试,以供社区参考

1. Over View of Pipeline

  类似于其他OpenMMLab项目,MMAction2 使用 Dataset 和 DataLoader 实现训练/测试阶段的数据加载。 由Dataset 返回一个字典Dict作为模型的输入。 而数据准备阶段pipeline和dataset类是相互解耦的,dataset类一般用于定义数据标注annotations,而数据pipeline用于加载数据(对于视频格式重点在于帧的采样)、预处理格式化。一个完整的数据pipeline结果如下图所示。
【MMAction2】学习笔记(1)——训练数据读取过程pipeline_第2张图片
  图中下侧IDE中显示的变量transform(list)读取了由config文件设置好的pipeline流程,同时也包含了实例化所需类的具体参数设置。图上侧是一个经典的pipeline结构。蓝色块是pipeline操作。随着其的深入,每个操作都向结果字典添加新键(标记为绿色)或更新现有键(标记为橙色)。而对于其三个主要部分:

  • 在加载数据部分,以裁片完成的数据集格式为例,对应数据集格式是rawframes。具体是基于SampleFrames类和RawFrameDecoder类实现的,其功能分别是获取需要的采样帧索引index和根据传递的索引来加载数据。代码位置为 mmaction/datasets/pipelines/loading.py
  • 在数据预处理部分,mmaction2内置多种数据增强的方法,以图中为例。具体是基于RandomRescaleRandomCropFlipNormalize类实现的。其作用分别是对裁剪后的图像序列进行随机尺寸变换、随机裁剪、随机翻转以及根据均值和方差对数据进行标准化。代码位置为mmaction/datasets/pipelines/augmentations.py
  • 在数据格式化部分,需要对图像数据按照config当中的设定进行格式上的变换并最终作为后续网络模型的输入。具体是基于FormatShape类、Collect类和ToTensor类实现的。其作用分别是将时序图像序列按照规定要求划分维度、将数据信息整合生成meta键(包含原始数据的各类信息)和将图像序列转化为torch框架训练所需的tensor格式。代码的位置为mmaction/datasets/pipelines/formatting.py

config配置文件及注释如下:

dataset_type = 'RawframeDataset'  # 数据集类型,RawframeDataset对应已完成视频切片的数据集
data_root = 'data/thyroid_CEUS/rawframes' # 切片图像存放地址

img_norm_cfg = dict( # 预先设定数据集的均值和方差,供后续normalize计算使用
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_bgr=False)

train_pipeline = [
    dict(type='SampleFrames', clip_len=8, frame_interval=32, num_clips=1),# 使用SampleFrames类进行索引选取,参数后续会解读
    dict(type='RawFrameDecode'), # 选取RawFrameDecoder根据索引来获取img图像序列
    dict(type='RandomRescale', scale_range=(256, 320)), # 尺寸调整范围设定
    dict(type='RandomCrop', size=224), # 裁剪对应尺寸设定
    dict(type='Flip', flip_ratio=0.5), # 翻转概率设定
    dict(type='Normalize', **img_norm_cfg), # 对应上面的标准化设定
    dict(type='FormatShape', input_format='NCTHW'), # 数据维度格式 (N,C,T,H,W)
    dict(type='Collect', keys=['imgs', 'label'], meta_keys=[]), # 编写meta元信息设定
    dict(type='ToTensor', keys=['imgs', 'label']) # 转化为tensor设定
]

2. Data Loading

【MMAction2】学习笔记(1)——训练数据读取过程pipeline_第3张图片
  对于SampleFrames即用于确定采样帧索引index的方法,在实例化类时三个定义的参数为Clip_lenNum_clipsFrame_interval。将原始视频格式数据分分块为Clip_len个视频片段,而每个视频片段由Num_clips个帧组成,而每个帧采样的间隔为Frame_interval。在进入调用SampleFrames后,首先跳转函数当中的_sample_clips_get_train_clip函数,根据总帧数和上述三个参数为代表的采样策略来确定偏移clip_offsets,而后回到SampleFrames的调用函数_ call _中输出采样帧的index索引。对应的是n维数组格式。
  得到包含索引index的result后输入给类RawFrameDecode的调用函数_ call _,根据frame_inds来采样帧,最终返回由所选帧组成的列表imgs,格式为数组array[(H,W,C)…]。


3. Pre-processing & Formatting

  对于图像预处理(Pre-processing)环节,以下图为例。mmaction2当中给出了多种数据增强-预处理方法,首先采用RandomResize类对图像尺度进行变换,而后通过RandomCrop类对图像进行范围内随机裁剪,接着根据设定概率依靠Flip类对图像施加翻转。在完成上述在线增强操作后,依据config设定当中的均值方差初始化Normalize类对图像进行标准化:

for img in imgs: # 遍历列表,对图像序列统一处理
    mmcv.imnormalize_(img, self.mean, self.std, self.to_bgr) # 标准化

每个步骤新增/删除/更新的Dict字段如图3所示,值得注意的是,针对每个视频所采样的序列,预处理是针对视频数据中所有帧进行的。
【MMAction2】学习笔记(1)——训练数据读取过程pipeline_第4张图片

  对于格式化(Formatting)环节,FormatShape类根据config文件中定义的类别对维度进行划分,以 NCTHW 为例,根据(Num_clips*Num_crops,Channels,Frames,Height,width)来重构维度

if self.input_format == 'NCTHW':
    num_clips = results['num_clips']
    clip_len = results['clip_len']

    imgs = imgs.reshape((-1, num_clips, clip_len) + imgs.shape[1:])
    # N_crops x N_clips x L x H x W x C
    imgs = np.transpose(imgs, (0, 1, 5, 2, 3, 4))
    # N_crops x N_clips x C x L x H x W
    imgs = imgs.reshape((-1, ) + imgs.shape[2:])
    # M' x C x L x H x W
    # M' = N_crops x N_clips

Collect类从加载器收集与特定任务相关的数据。而ToTensor类通过调用to_tensor函数实现将数据由数组ndarray转化为tensor格式供后续训练。至此,一个数据pipeline基本构成。

你可能感兴趣的:(学习,pytorch,深度学习,python)