CVPR2023 Autoregressive Visual Tracking 理解记录

ARTrack code with comments
https://github.com/MIV-XJTU/ARTrack
ARTrack的框架:
CVPR2023 Autoregressive Visual Tracking 理解记录_第1张图片

代码训练主要分为两阶段:

  • 第一阶段就是和seqtrack是一样的,就是template和search的图像打成patches送进transformer的encoder和decoder,只不过decoder这里送的query tokens送的是一个[cmd]或者[start] token,然后加x,y,w token, 序列化的顺序预测x,y,w,h目标位置信息,因为在预测x的时候只知道[start] token, 预测y的时候只知道[start] token和x token, 以此类推,所以motivation里面常写,如果模型知道目标在哪里,就能给一个命令就依次把目标的位置读出来。所以并不是给真值预测真值的看似白痴的学习。这里有两点需要注意的:

    • 这里的坐标变成token会经过一个word to embedding的过程,实现上来看就是把坐标当做index索引,会有一个embedding vocabulary字典被索引,经过坐标的索引出来的嵌入才会送入decoder里面。这样做的好处论文中解释为: This novel regression avoids direct non-linear mapping from image features to coordinates, which is often difficult.
    • 在训练阶段为了训练效率,通常会直接使用nn.MultiheadAttentionattn_mask来进行掩蔽,也就是预测当前的token的时候只能看见之前的token, 因为训练阶段如果也是一个一个顺序预测,训练效率就会很低,这也就是为什么ARTrack采用了第一阶段训练个比较可以baseline了,然后加上训练效率比较低的第二阶段,这样能够保证总体训练效率和每阶段加上都能有提升
      CVPR2023 Autoregressive Visual Tracking 理解记录_第2张图片
      ps: 可能去看第一阶段代码的时候会碰到下面这个
    	gt_bbox = gt_bbox.clamp(min=(-1*magic_num), max=(1+magic_num))
        seq_ori = (gt_bbox + magic_num) * (self.bins - 1)
    

    第一句话就是论文里面介绍的坐标映射到全局及夹限(虽然第一阶段用不上,但是为了衔接对齐第二阶段,第一阶段就需要一样的训练),本来坐标在[0.0, 1.0]现在就能够囊括[-0.5, 1.5],能够cover更多相邻帧中的目标坐标
    第二句话加了一个magic_num使之不会为负的,这样很巧妙的和range*bins对应上了,这个在预测完需要减回来

  • 至此第一阶段就是利用当前search的y,w,h预测x,y,w,h(或者x,y,x,y坐标体系也一样的),非常好理解,主要就是第二阶段的sequence-level 训练比较难以理解,还得进一步看代码。第二阶段重新实现了一个dataset和dataloader,为了能够获取一个视频底下的多帧search frame。

  • 可以看一下第二阶段的Actor,这里代码中主要用到的是explorecompute_sequence_losses两个函数,

    • 其中explore主要进行的是batch_init和batch_track,也就是输入一个batch的数据(batch中的单个元素是1张template和32张search frames【这32张是连续的,通过_sequential_sample采样得到,不过会和template有所间隔,也就是第一张search不是和template紧紧相连的】,而且这里是未进行crop成128/256的,因为crop会在这个batch_initbatch_track里面完成),由第一阶段训好的模型进行推理得到每张search下的结果,所以你会在这里看到本该在推理时看到的_bbox_clip和get_subwindow函数。explore返回的是一个字典,包含着template_images、search_images、search_anno、pre_seq、x_feat、baseline_iou等键和其对应的值。这些结果的含义分别是
      • template_images:torch.Tensor [B,3,128,128]就是根据img和template_bbox裁剪下来的模板
      • search_images: torch.Tensor [num_frames-1,B,3,128,128]就是每次根据前一帧预测结果crop下来的所有序列下的多张搜索区域
      • search_anno: torch.Tensor [num_frames-1,B,4]是gt bbox在每次crop下来的搜索区域中的坐标,形式为x,y,w,h,不是归一化的。因为这次训练阶段模拟的推理,所以gt坐标我们是知道的
      • pre_seq: torch.Tensor [num_frames-1,B,4✖pre_num]由每次batch_track获得的[B,4*pre_num] stack而成的,就是相当于全图(全局坐标系,我们只能假设相机是不动的)下的之前帧的坐标的一个缓存队列,是所有之前邻域内在raw image水平上的坐标映射到当前搜索区域里面的归一化坐标(这个映射主要靠与当前center_sz相减来完成的,其实就是transform_image_to_crop)
      • x_feat: torch.Tensor [num_frames-1,B,N,C]他实现的时候在decoder里面还有个encoder,就是这里的输出的搜索区域part的特征,N是token数量,C一般是768
      • baseline_iou:torch.Tensor [num_frames-1,B] 预测框和真值框的IoU,可以作为一个指标监控第二阶段初始阶段是不是有个较高的指标
        这上面都减1的原因是num_frames张search frames的头一张被用来确定目标首帧的位置了,相当于确定初始要跟踪的帧。
        CVPR2023 Autoregressive Visual Tracking 理解记录_第3张图片
        这里对batch_track进行总结一下,总体还是z_crop和x_crop送入encoder和decoder进行特征提取和交互,最后在decoder的query input的时候是输入序列化的坐标token,因为这种序列化的形式,就能够灵活地变长度,可以把之前pre_num帧的坐标囊括进来,提供了一种motion cues。但是当前坐标预测的时候是没有用到之前的search特征的。其实decoder包括self-attn和cross-attn,其中self-attn会用到causal attention mask,这个无论在train/test phase都会用到,只不过test阶段要根据前一个坐标推下一个坐标,是自回归的(train的时候本质也是自回归,但是有了真值和causal mask就显得是parallel了),所以会循环4次(首次输入的是[start] token)。
    • 其中compute_sequence_losses就是计算第二阶段训练的损失,好进行反向传播,并且梯度更新。下面这幅图显示的就是调用的时候,可以看到是以一个序列下num_frames-1张进行一次损失的计算。
      CVPR2023 Autoregressive Visual Tracking 理解记录_第4张图片
      具体看下图可知里面损失的计算包括一个cross_entropy loss和siou loss。具体就是把这同一个序列下连续的num_frames-1张search region都再和template做相关,然后用前面的7帧的历史坐标来推断当前的坐标,用search_anno也就是gt_in_crop来监督预测出来的坐标。
      所以回顾一下第二阶段的训练流程: 其实就是拿num_frames张search进行推理,得到推理的结果,以此我们就有了序列历史帧的坐标了,以历史帧的坐标为基础再前向推理一下,但这时候是训练(因为用上了真值进行损失计算梯度回传),所以这样保持了训练和推理的一致,训练的目标就是使推理这连续的num_frames-1张结果更准,就是一步一步在这个推理训练推理训练的循环中使得跟踪器更加拟合一个video clip的跟踪结果。其实compute_sequence_losses里面的前向传播和batch_track里面的很像,仔细想就是唯一的区别就是:现在计算损失的前向传播是每帧都已经有了前7帧历史坐标,可以并行化的自回归出结果。并且有真值进行监督。实际代码里的num_frames比论文中宣称的16要更长,为36。
      CVPR2023 Autoregressive Visual Tracking 理解记录_第5张图片

下面也就两个阶段的测试进行讲解:

  • 第一阶段的推理也是和seqtrack是一样的
  • 第二阶段的推理就是会有一个队列存储7帧的历史帧坐标,用来在decoder端作为和[start token]一起推出当前坐标的query。

一些pytorch语法

  1. nn.Module实例化的model设置training为True或者False,并不影响这个模型parameters的requires_grad
  2. 可以多次loss.backward()然后进行一次optimizer.step(),这就和gradient accumulation是类似的。

你可能感兴趣的:(目标跟踪,计算机视觉,目标跟踪)