[mmaction2版本] 视频分类(二) TIN:Temporal Interlacing Network 原理及代码讲解

接着上一篇文章TSM视频理解, 今天介绍新的视频分类网络TIN(Temporal Interlacing Network)。相对于TSM,TIN可以更灵活的基于交错网络预测出我们的特征图的随着时间的偏移量值而不是向TSM每次移动一位进行特征融合,如果不了解的具体可以通过下面的内容进行理解。我们仍然基于mmaction框架进行讲解。
paper: Temporal Interlacing Network
code: mmaction2

一、 原理介绍


该模型是基于时间交错网络进行行为识别,当时其在速度上实现了SOTA 6倍的加速,同时在准确率上实现了4%的提升。该方法的思想和TSM的思想一样都是希望将时间信息嵌入到空间信息特征去,以便可以同时一次同时联合学习两个域中的信息。作者发明此网络的直觉做出了如下解释

In order to integrate temporal information at different times, we can provide different frames with a unique interlacing offset. Instead of habitually assigning each channel with a separately learnable offset, we adopt distinctive offsets for different channel groups. As observed in SlowFast (Feichtenhofer et al. 2018), human perception on object motion focuses on different temporal resolutions. To maintain temporal fidelity and recognize spatial semantics jointly, different groups of temporal receptive fields pursuit a thorough separation of expertise convolution. Besides, groups of offsets also reduce the model complexity as well as stabilize the training procedure across heavy backbone architectures.(为了整合不同时间的时间信息,我们可以为不同的帧提供独特的交错偏移。 我们没有习惯性地为每个通道分配一个可单独学习的偏移量,而是为不同的通道组采用不同的偏移量。 正如在 SlowFast (Feichtenhofer et al. 2018) 中所观察到的,人类对物体运动的感知侧重于不同的时间分辨率。 为了保持时间保真度并共同识别空间语义,不同组的时间感受野追求专业卷积的彻底分离。 此外,偏移组还降低了模型的复杂性,并稳定了跨重型骨干架构的训练过程)

主要原理图如下所示:

Deformable Shift Module

时间交错框架如上图所示,如图(a) 展示(b)结构位于残差神经网络之前的结构。对于整个特征图,会将前 的通道保持不变,再会将 的通道进行分组,这里我们会分为4组,(两组沿着T维度的偏移量,剩下的两组是偏移量是这前两组的相反值, 这样做可以保证信息在时序维度上的流动是对称的,有利于 后续特征的融合)因为作者实验发现两组的效果是最好的, 这样每组对应不同偏移量。
Accuracies with different numbers of groups and reverse offsets

这些偏移量是怎么预测出来的呢?还是要对应上图的原理图, 对于输入的特征图我们首先会输入到3D平均池化网络,接着分别输入到OffsetNet网络以及WeightNet在将两者结合即可得到我们的偏移网络的特征图。OffsetNet主要负责预测偏移量而WeightNet主要负责预测融合后的时序维度上的特征权重。

如果原始输入是8帧,该网络便会为每组输出8个值分别代表每一帧的权重然后会直接用此值来加权融合过后每一帧的feature。我们也同时发现位于两端的帧所预测的权重大多会比较低,这里我们的猜想是两端的帧的特征在沿着时序移动时由于一边没有其他帧会损失掉一部分,因此导致了网络给他们一个较低的权重来弥补信息损失带来的影响。

可微模块的具体框架如下所示:


它可以将各组按channel维度切分出来的特征沿着时间维度移动任意个单位。其实现方式主要是通过一维线性差值实现的。其中我们还采用了时序扩展技术,以保证偏移之后位于视频之外的特征不为空。举个例子,原本位于T=0的特征在向前偏移0.5个单位后便位于T=-0.5的位置,该位置理论上是不存在特征的,但我们通过假设T=-1位置的特征全为0使位于-0.5的位置取到了特征,也即Feature(T=-0.5) = (Feature(T=-1) + Feature(T=0))。

1.1. Temporal-wise Frame Sampling

这里需要好好讲解Temporal-wise Frame Sampling, 该过程是一个线性插值的过程。

针对上面的描述,这边用一张图片来进行解释说明。


Temporal-wise Frame Sampling

1.2. Temporal Extension
Temporal Extension

部分特征可能被移出而变为0,进而在训练阶段损失梯度。输入范围是[1, T],为了减轻这个现象带来的影响,设置一个buffer来存储处于(0,1)与(T,T+1)间隔中被移出的特征。超出T+1与小于0的部分会被置0

1.3. Temporal Attention

关于这里的Temporal Attentation则是基于WeightNet生成的权重进行, 与OffsetNet进行组合。

二、 代码介绍


这里有关于数据及数据预处理可以参考前面的[mmaction2版本] 视频分类(一) TSM:Temporal Shift Module for Efficient Video Understanding 原理及代码讲解这篇博客进行理解。

2.1. 特征提取网络

本代码的特征提取网络是基于Resnet模型基础上进行改进。代码如下所示:

blocks = list(stage.children())
for i, b in enumerate(blocks):
    if i % n_round == 0:
        tds = TemporalInterlace(
                   b.conv1.in_channels,
                   num_segments=num_segments,
                   shift_div=shift_div)
        blocks[i].conv1.conv = CombineNet(tds, blocks[i].conv1.conv)
return nn.Sequential(*blocks)

self.layer1 = make_block_interlace(self.layer1, num_segment_list[0], self.shift_div)
self.layer2 = make_block_interlace(self.layer2, num_segment_list[1], self.shift_div)
self.layer3 = make_block_interlace(self.layer3, num_segment_list[2], self.shift_div)
self.layer4 = make_block_interlace(self.layer4, num_segment_list[3], self.shift_div)

我们先看下self.layer1

Sequential(
  (0): Bottleneck(
    (conv1): ConvModule(
      (conv): CombineNet(
        (net1): TemporalInterlace(
          (offset_net): OffsetNet(
            (sigmoid): Sigmoid()
            (conv): Conv1d(16, 1, kernel_size=(3,), stride=(1,), padding=(1,))
            (fc1): Linear(in_features=8, out_features=8, bias=True)
            (relu): ReLU()
            (fc2): Linear(in_features=8, out_features=2, bias=True)
          )
          (weight_net): WeightNet(
            (sigmoid): Sigmoid()
            (conv): Conv1d(16, 2, kernel_size=(3,), stride=(1,), padding=(1,))
          )
        )
        (net2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
    (downsample): ConvModule(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (1): Bottleneck(
    (conv1): ConvModule(
      (conv): CombineNet(
        (net1): TemporalInterlace(
          (offset_net): OffsetNet(
            (sigmoid): Sigmoid()
            (conv): Conv1d(64, 1, kernel_size=(3,), stride=(1,), padding=(1,))
            (fc1): Linear(in_features=8, out_features=8, bias=True)
            (relu): ReLU()
            (fc2): Linear(in_features=8, out_features=2, bias=True)
          )
          (weight_net): WeightNet(
            (sigmoid): Sigmoid()
            (conv): Conv1d(64, 2, kernel_size=(3,), stride=(1,), padding=(1,))
          )
        )
        (net2): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
  (2): Bottleneck(
    (conv1): ConvModule(
      (conv): CombineNet(
        (net1): TemporalInterlace(
          (offset_net): OffsetNet(
            (sigmoid): Sigmoid()
            (conv): Conv1d(64, 1, kernel_size=(3,), stride=(1,), padding=(1,))
            (fc1): Linear(in_features=8, out_features=8, bias=True)
            (relu): ReLU()
            (fc2): Linear(in_features=8, out_features=2, bias=True)
          )
          (weight_net): WeightNet(
            (sigmoid): Sigmoid()
            (conv): Conv1d(64, 2, kernel_size=(3,), stride=(1,), padding=(1,))
          )
        )
        (net2): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv2): ConvModule(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activate): ReLU(inplace=True)
    )
    (conv3): ConvModule(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (relu): ReLU(inplace=True)
  )
)
2.2 TemproalInterplace

执行代码如下:

class TemporalInterlace(nn.Module):
    """Temporal interlace module.

    This module is proposed in `Temporal Interlacing Network
    `_

    Args:
        in_channels (int): Channel num of input features.
        num_segments (int): Number of frame segments. Default: 3.
        shift_div (int): Number of division parts for shift. Default: 1.
    """

    def __init__(self, in_channels, num_segments=3, shift_div=1):
        super().__init__()
        self.num_segments = num_segments
        self.shift_div = shift_div
        self.in_channels = in_channels
        # hard code ``deform_groups`` according to original repo.
        self.deform_groups = 2

        self.offset_net = OffsetNet(in_channels // shift_div,
                                    self.deform_groups, num_segments)
        self.weight_net = WeightNet(in_channels // shift_div,
                                    self.deform_groups)

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        # x: [N, C, H, W],
        # where N = num_batches x num_segments, C = shift_div * num_folds
        n, c, h, w = x.size() # n=48 c=64, h=56, w=56
        #print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        #print(x.size())
        #print("#####################################################")
        num_batches = n // self.num_segments
        num_folds = c // self.shift_div # self.shift_div=4

        # x_out: [num_batches x num_segments, C, H, W]
        x_out = torch.zeros((n, c, h, w), device=x.device) # x_out shape=[48, 64, 56, 56]
        # x_descriptor: [num_batches, num_segments, num_folds, H, W]
        # num_folders=16
        x_descriptor = x[:, :num_folds, :, :].view(num_batches,
                                                   self.num_segments,
                                                   num_folds, h, w)
        # x_descriptor shape [6, 8, 16, 56, 56]

        # x should only obtain information on temporal and channel dimensions
        # x_pooled: [num_batches, num_segments, num_folds, W]
        x_pooled = torch.mean(x_descriptor, 3)
        # x_pooled: [num_batches, num_segments, num_folds]
        x_pooled = torch.mean(x_pooled, 3)
        # x_pooled: [num_batches, num_folds, num_segments]
        x_pooled = x_pooled.permute(0, 2, 1).contiguous()# x_pooled shape=[6, 16, 8]

        # Calculate weight and bias, here groups = 2
        # x_offset: [num_batches, groups]
        x_offset = self.offset_net(x_pooled).view(num_batches, -1) # x_offset shape [6, 2]
        # x_weight: [num_batches, num_segments, groups]
        x_weight = self.weight_net(x_pooled)

        # x_offset: [num_batches, 2 * groups]
        x_offset = torch.cat([x_offset, -x_offset], 1) # x_offset shape [6, 4]
        # x_shift: [num_batches, num_segments, num_folds, H, W]
        x_shift = linear_sampler(x_descriptor, x_offset)

        # x_weight: [num_batches, num_segments, groups, 1]
        x_weight = x_weight[:, :, :, None]
        # x_weight:
        # [num_batches, num_segments, groups * 2, c // self.shift_div // 4]
        x_weight = x_weight.repeat(1, 1, 2, num_folds // 2 // 2)
        # x_weight:
        # [num_batches, num_segments, c // self.shift_div = num_folds]
        x_weight = x_weight.view(x_weight.size(0), x_weight.size(1), -1)

        # x_weight: [num_batches, num_segments, num_folds, 1, 1]
        x_weight = x_weight[:, :, :, None, None]
        # x_shift: [num_batches, num_segments, num_folds, H, W]
        x_shift = x_shift * x_weight
        # x_shift: [num_batches, num_segments, num_folds, H, W]
        x_shift = x_shift.contiguous().view(n, num_folds, h, w)

        # x_out: [num_batches x num_segments, C, H, W]
        x_out[:, :num_folds, :] = x_shift
        x_out[:, num_folds:, :] = x[:, num_folds:, :]

        return x_out

首先输入x shape为[48, 3, 224, 224], 对应的含义分别是[batch_size, channel, height, width]。在进行convmax pool得到特征图大小为[48, 64, 56, 56]在输入到上述模型代码中。
n, c, h, w = x.size(), 这里的n=48, c=64, h=56, w=56, num_batches=6, num_folders=16, 再通过x_descriptor shape 为 [num_batches, num_segments, C, H, W](这里shape为[6, 8, 16, 56, 56])x_out shape为[48, 64, 56, 56]
根据论文中提到的公式

后面在通过求平均的方式torch.mean对空间信息进行平均信息压缩,如下面代码所示:

x_pooled = torch.mean(x_descriptor, 3)
x_pooled = torch.mean(x_pooled, 3)

我们得到x_pooledshape为[6, 16, 8], 之后将该结果输入到Offset Net

2.2.1 Offset Net

先给出代码

class OffsetNet(nn.Module):
    """OffsetNet in Temporal interlace module.

    The OffsetNet consists of one convolution layer and two fc layers
    with a relu activation following with a sigmoid function. Following
    the convolution layer, two fc layers and relu are applied to the output.
    Then, apply the sigmoid function with a multiply factor and a minus 0.5
    to transform the output to (-4, 4).

    Args:
        in_channels (int): Channel num of input features.
        groups (int): Number of groups for fc layer outputs.
        num_segments (int): Number of frame segments.
    """

    def __init__(self, in_channels, groups, num_segments):
        super().__init__()
        self.sigmoid = nn.Sigmoid()
        # hard code ``kernel_size`` and ``padding`` according to original repo.
        kernel_size = 3
        padding = 1

        self.conv = nn.Conv1d(in_channels, 1, kernel_size, padding=padding)
        self.fc1 = nn.Linear(num_segments, num_segments)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(num_segments, groups)

        self.init_weights()

    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""
        # The bias of the last fc layer is initialized to
        # make the post-sigmoid output start from 1
        self.fc2.bias.data[...] = 0.5108

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        # calculate offset
        # [N, C, T]
        # x shape=[6, 16, 8]
        n, _, t = x.shape # n=6, t=8
        # [N, 1, T]
        x = self.conv(x) # conv1d[16, 1], kernel_size=3  x shape=[6, 1, 8] 相当于在通道维度降维
        # [N, T]
        x = x.view(n, t) # x shape [6, 8]
        # [N, T]
        x = self.relu(self.fc1(x)) # fc1 [8, 8] x shape[6,8]
        # [N, groups]
        x = self.fc2(x) # fc2 shape [8, 2] x shape [6,2]
        # [N, 1, groups]
        x = x.view(n, 1, -1) # x shape [6, 1, 2]

        # to make sure the output is in (-t/2, t/2)
        # where t = num_segments = 8
        x = 4 * (self.sigmoid(x) - 0.5) # x shape [6, 1, 2]  t=8 so T= 8/2=4
        # [N, 1, groups]
        return x

根据论文中的公式:


首先对于输入x shape为[6 ,16, 8]通过fc1以及relu得到输出shape为[6, 8], 再将其输入到fc2网络中,这里的fc2的输出通道为2, 因为这里的group设置为2

        self.deform_groups = 2

所以输出shape为[6, 1, 2]。再经过如下公式:


这里我们设置T为4,即T=t/2(t=num_segments), 输出x的范围为[-2, 2]并且shape为[6, 1, 2] 对应的代码如下所示:

# to make sure the output is in (-t/2, t/2)
# where t = num_segments = 8
x = 4 * (self.sigmoid(x) - 0.5) # x shape [6, 1, 2]  t=8 so T= 8/2=4
# [N, 1, groups]
return x
2.2.2 Weight Net

同时我们将x并行输入到Weight Net, 首先先给出代码

class WeightNet(nn.Module):
    """WeightNet in Temporal interlace module.

    The WeightNet consists of two parts: one convolution layer
    and a sigmoid function. Following the convolution layer, the sigmoid
    function and rescale module can scale our output to the range (0, 2).
    Here we set the initial bias of the convolution layer to 0, and the
    final initial output will be 1.0.

    Args:
        in_channels (int): Channel num of input features.
        groups (int): Number of groups for fc layer outputs.
    """

    def __init__(self, in_channels, groups):
        super().__init__()
        self.sigmoid = nn.Sigmoid()
        self.groups = groups

        self.conv = nn.Conv1d(in_channels, groups, 3, padding=1)

        self.init_weights()

    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""
        # we set the initial bias of the convolution
        # layer to 0, and the final initial output will be 1.0
        self.conv.bias.data[...] = 0

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        # calculate weight
        # [N, C, T]
        # x shape=[6, 16, 8]
        n, _, t = x.shape
        # [N, groups, T]
        x = self.conv(x) # x shape [6, 2, 8]
        x = x.view(n, self.groups, t) # x shape [6, 2, 8]
        # [N, T, groups]
        x = x.permute(0, 2, 1) # x shape [6, 8, 2]

        # scale the output to range (0, 2)
        x = 2 * self.sigmoid(x)
        # [N, T, groups]
        return x

对应的最后x输出的shape为[6, 8, 2]范围是(0, 2)

2.2.3 Offset Net与 Weight Net结合
  1. 首先 x_offset = torch.cat([x_offset, -x_offset], 1)offset做对称。
  2. 再去取其权重及对应的特征
    x_shift = linear_sampler(x_descriptor, x_offset)
    具体代码如下
def linear_sampler(data, offset):
    """Differentiable Temporal-wise Frame Sampling, which is essentially a
    linear interpolation process.

    It gets the feature map which has been split into several groups
    and shift them by different offsets according to their groups.
    Then compute the weighted sum along with the temporal dimension.

    Args:
        data (torch.Tensor): Split data for certain group in shape
            [N, num_segments, C, H, W].
        offset (torch.Tensor): Data offsets for this group data in shape
            [N, num_segments].
    """
    # [N, num_segments, C, H, W]
    n, t, c, h, w = data.shape

    # offset0, offset1: [N, num_segments]
    offset0 = torch.floor(offset).int() # offset range [-2, 1]
    offset1 = offset0 + 1 # offset1 rang e[-1, 2] # 可以看出offset0 与offset1是对称左右移动

    # data, data0, data1: [N, num_segments, C, H * W]
    data = data.view(n, t, c, h * w).contiguous() # data shape [6, 8, 16, 3136]

    try:
        from mmcv.ops import tin_shift
    except (ImportError, ModuleNotFoundError):
        raise ImportError('Failed to import `tin_shift` from `mmcv.ops`. You '
                          'will be unable to use TIN. ')
    # data shape [6, 8, 16, 3136]
    data0 = tin_shift(data, offset0) # data0 shape [6, 8, 16, 3136]
    data1 = tin_shift(data, offset1)

    # weight0, weight1: [N, num_segments]
    weight0 = 1 - (offset - offset0.float())
    weight1 = 1 - weight0

    # weight0, weight1:
    # [N, num_segments] -> [N, num_segments, C // num_segments] -> [N, C]
    group_size = offset.shape[1]
    weight0 = weight0[:, :, None].repeat(1, 1, c // group_size)
    weight0 = weight0.view(weight0.size(0), -1)
    weight1 = weight1[:, :, None].repeat(1, 1, c // group_size)
    weight1 = weight1.view(weight1.size(0), -1)

    # weight0, weight1: [N, C] -> [N, 1, C, 1]
    weight0 = weight0[:, None, :, None]
    weight1 = weight1[:, None, :, None]

    # output: [N, num_segments, C, H * W] -> [N, num_segments, C, H, W]
    output = weight0 * data0 + weight1 * data1
    output = output.view(n, t, c, h, w)

    return output

代码中offset0对应图片中
上式output = weight0 * data0 + weight1 * data1 即反映了文章的精华。我们继续拿上面的这张图来解释这里的代码中weight0对应图中, weight1对应图中。
weight0 * data0 + weight1 * data1对应论文中


最终得到x_shift, 再加上weight Net得到的权重注意力相乘得到其结果x_shift = x_shift * x_weight。 这里的tin_shift原理可以看.cuh代码如下所示:

template 
__global__ void tin_shift_forward_cuda_kernel(
    const int nthreads, const T* input, const int* shift, T* output,
    const int batch_size, const int channels, const int t_size,
    const int hw_size, const int group_size, const int group_channel) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    const int hw_index = index % hw_size;
    const int j = (index / hw_size) % channels;

    const int n_index = (index / hw_size / channels) % batch_size;
    int group_id = j / group_channel;
    int t_shift = shift[n_index * group_size + group_id];
    int offset = n_index * t_size * hw_size * channels + hw_size * j + hw_index;
    for (int i = 0; i < t_size; i++) {
      int now_t = i + t_shift;
      int data_id = i * hw_size * channels + offset;
      if (now_t < 0 || now_t >= t_size) {
        continue;
      }
      int out_id = now_t * hw_size * channels + offset;
      output[out_id] = input[data_id];
    }
  }
}

剩下的部分就很简单了,和TSM原理类似, 这里就不作多余解释了。
总结下,据我的理解是相对于TSM,在时间上基于OffsetNet偏移量是可以训练的,再通过WeightNet可以给偏移量加权重,给更合适的偏移量更高的权重。有一个疑问就是为什么这边的偏移量范围是[-2, 2]的范围,我这里的理解是相对于TSM增大了时间维度的感受野,如果更大则很多信息溢出了T,导致无法获取,所以这边进行了平衡,如果有其他不同的观点欢迎提出。

参考资料

【1】MMIT冠军方案|用于行为识别的时间交错网络,商汤公开视频理解代码库

你可能感兴趣的:([mmaction2版本] 视频分类(二) TIN:Temporal Interlacing Network 原理及代码讲解)