《TSM: Temporal Shift Module for Efficient Video Understanding》 论文与代码解析

论文地址:https://arxiv.org/pdf/1811.08383.pdf
代码地址:https://github.com/mit-han-lab/temporal-shift-module
  TSM,简而言之,是一个视频理解网络,能够以2D卷积的计算量实现3D卷积的效果。下面将会根据代码和论文帮助大家理解这个网络。

视频理解概述:

       目前有两种主流方法:2D CNN和3D CNN。2D CNN:1.双流网络(RGB和光流) 2.TSN(Temporal Segment Networks)。3D CNN :C3D。

数据预处理:

       somethingv1,v2 和kinetics的数据预处理在官方代码的tools路径下都有,但是这两个数据集过于庞大,于是选择了相对较小的UCF101,UCF101数据集的下载方法及处理方法可以使用PaddleVideo的工具,链接。
       数据下载完以后,应该是下面的结构:
《TSM: Temporal Shift Module for Efficient Video Understanding》 论文与代码解析_第1张图片
annotations/classInd.txt是类别ID和类别名
《TSM: Temporal Shift Module for Efficient Video Understanding》 论文与代码解析_第2张图片
rawframes是从视频中提取出的图片。结构是这样的:
《TSM: Temporal Shift Module for Efficient Video Understanding》 论文与代码解析_第3张图片
ucf101_train_split_1_rawframes.txt,ucf101_val_split_1_rawframes.txt是划分的训练集和验证集的列表。
《TSM: Temporal Shift Module for Efficient Video Understanding》 论文与代码解析_第4张图片
第一列是各个视频图片切片的地址,第二列是这个视频包含的图片切片的数量,第三列是这个视频的label,是从0开始算的,这一点和classInd.txt不同。

预训练模型:

链接,预训练模型一定要下载,不然跑不出来跟官方一样的结果。

网络与代码解析

《TSM: Temporal Shift Module for Efficient Video Understanding》 论文与代码解析_第5张图片
       TSM的原理其实用上面的一张图就可以完全说的清,上面的b,c分别代表两种模式,离线模式和在线模式,分别用于跑离线视频和在线视频,离线视频可以将未来帧、当前帧、过去帧信息融合,在线视频只能将过去帧和当前帧融合,非常容易理解。每一行代表同一帧的不同channel, 那么在线模式就是将channel的 [0:partial]区间整体往前移一帧,那么第一帧用0 PAD,[partial:2*partial]的区间整体往后移一帧,那么最前面一帧的特征用0 pad,其余Channel保持不变。partial可以调节,例如如果等于1/8 channel数,那么有1/4的channel发生移动。离线模式则都往前移一帧。
《TSM: Temporal Shift Module for Efficient Video Understanding》 论文与代码解析_第6张图片
作者通过做实验发现移动1/4channel(离线模式是包括前移1/8,后移1/8)效果最好(精度和latency)。
此外,作者还介绍了两种模式,inplace tsm 和 residual tsm。
《TSM: Temporal Shift Module for Efficient Video Understanding》 论文与代码解析_第7张图片
作者说明residual tsm效果比较好,因为它不仅有temporal shift 分支,还保留了原有的分支。使用residual tsm会导致部分当前帧的信息丢失。
temoral shift的代码在ops/temporal_shift.py中。

def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters
        :return:
        """
        super(TSN, self).train(mode)
        count = 0
        if self._enable_pbn and mode:
            print("Freezing BatchNorm2D except the first one.")
            for m in self.base_model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    count += 1
                    if count >= (2 if self._enable_pbn else 1):
                        m.eval()
                        # shutdown update in frozen mode
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False

TIPS

作者freeze了除了第一个的所有BN层的参数,使用全局的std和mean,因为做卷积的时候是将不同batch的clip一起做的,这样如果整个batch去做BN的话,会融合其他batch clip的信息,可能会一定程度湮没本batch不同clip之间的信息交换(个人看法)。

def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters
        :return:
        """
        super(TSN, self).train(mode)
        count = 0
        if self._enable_pbn and mode:
            print("Freezing BatchNorm2D except the first one.")
            for m in self.base_model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    count += 1
                    if count >= (2 if self._enable_pbn else 1):
                        m.eval()
                        # shutdown update in frozen mode
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False

简单做了个实验,将freeze BN去掉。在ucf101上的实验结果从95.85降到了 93.319

实验结果:

《TSM: Temporal Shift Module for Efficient Video Understanding》 论文与代码解析_第8张图片
offline能够有和online接近的精度

《TSM: Temporal Shift Module for Efficient Video Understanding》 论文与代码解析_第9张图片
和3D模型比,TSM的精度和FLOPS都有优势。

你可能感兴趣的:(《TSM: Temporal Shift Module for Efficient Video Understanding》 论文与代码解析)