SORT-3 匈牙利算法和SORT类

SORT系列

SORT-1 项目配置运行-WINDOWS

SORT-2 SORT流程&卡尔曼滤波推导和底层代码详解

匈牙利算法

有权二部图的最大匹配。

讲解视频:二部图和匈牙利算法
利用匈牙利算法对目标框和检测框进行关联:匈牙利进行关联
API文档:Linear_sum_assignment
底层代码详解:Linear_sum_assignment源码详解

匈牙利算法在 SORT 中的使用: MOT - 数据关联

SORT类代码详解

流程图

SORT 类 Update 函数的流程图:
SORT-3 匈牙利算法和SORT类_第1张图片

几个细节的问题

参考:目标跟踪:yolov4目标检测 + sort目标跟踪
接下来就是逐帧怎样进行预测和更新;
1、怎样predict ?
2、怎样update ?
3、trk.hit_streak怎样实现连击 >= min_hits时,赋予该trk一个id ?
4、trk.time_since_update > max_age时,删除该轨迹trk的实现方式?

4个小问题精髓答案:

  1. update matched trackers with assigned detections
  2. 只有匹配上才会update,此时,每个track的time_since_update重新归置为0,
  3. 没有匹配上轨迹的detections,将会赋予新的轨迹。
  4. 新轨迹连续匹配上update时,hit_streak += 1,大于min_hits时才会赋予新的id;
    新轨迹未如果没匹配上detection,这个track的time_since_update+=1,重新将hit_streak = 0,直到这个trk.time_since_update>max_age, 删除该轨迹

代码详细注解

  def update(self, dets=np.empty((0, 5))):
    """
    Params:
      dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...] 检测框列表
    Requires: this method must be called once for each frame even with empty detections (use np.empty((0, 5)) for frames without detections).
    Returns the a similar array, where the last column is the object ID.

    NOTE: The number of objects returned may differ from the number of detections provided.
    """
    self.frame_count += 1
    # get predicted locations from existing trackers.
    trks = np.zeros((len(self.trackers), 5))    # [[x1,y1,x2,y2,ID],...]
    to_del = []             # 待删除列表
    ret = []                # 个人理解为,可满足显示条件的跟踪器

    """
    1. 现有的跟踪器(列表),(全部)做一次预测
    """
    for t, trk in enumerate(trks):
      pos = self.trackers[t].predict()[0]           # 在已有的跟踪器上,做一次预测
      trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]  # 获得预测框的坐标
      if np.any(np.isnan(pos)):
        to_del.append(t)            # 若有非法值则加入 to_del

    # masked_invalid : 对掩码数组中的无效值做掩码处理;
    # ma.compress_rows : 抑制包含屏蔽值的二维数组的行和/或列,默认则抑制行+列
    trks = np.ma.compress_rows(np.ma.masked_invalid(trks))

    # 删除 to_del 待删跟踪器
    for t in reversed(to_del):
      self.trackers.pop(t)

    """
     2. 预测结果(预测框)和检测框(参数传入)做一次匹配
     """
    # 做一次 检测框 、预测框 关联
    matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets,trks, self.iou_threshold)

    """
     3. 根据匹配结果,分别更新 matched / unmatched_detections / unmatched_trackers 三类
     """
    # update matched trackers with assigned detections
    # 更新匹配成功的跟踪器
    for m in matched:
      self.trackers[m[1]].update(dets[m[0], :])

    # create and initialise new trackers for unmatched detections
    # 为未匹配的检测框创建一个新的跟踪器
    for i in unmatched_dets:
        trk = KalmanBoxTracker(dets[i,:])
        self.trackers.append(trk)

    i = len(self.trackers)
    for trk in reversed(self.trackers):   # reversed : 返回一个反转的迭代器
        d = trk.get_state()[0]            # get_state() 返回 bbox

        # 判断该跟踪器是时间、匹配次数是否满足条件,是则加入 ret 中
        if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
          ret.append(np.concatenate((d,[trk.id+1])).reshape(1,-1)) # +1 as MOT benchmark requires positive
        i -= 1
        # remove dead tracklet 如果最大年龄超限,则删除该跟踪器
        if(trk.time_since_update > self.max_age):
          self.trackers.pop(i)

    if(len(ret)>0):
      return np.concatenate(ret)    # 拼接

    return np.empty((0,5))

主函数

主函数流程:
SORT-3 匈牙利算法和SORT类_第2张图片
代码详细注释

if __name__ == '__main__':
  """
  
  此例程使用现成的检测结果,保存在指定文件中;而不是实时获取检测框。
  
  """
  # all train
  args = parse_args()
  display = args.display
  phase = args.phase  # train / test
  total_time = 0.0
  total_frames = 0
  colours = np.random.rand(32, 3) #used only for display
  if(display):
    if not os.path.exists('mot_benchmark'):
      print('\n\tERROR: mot_benchmark link not found!\n\n    Create a symbolic link to the MOT benchmark\n    (https://motchallenge.net/data/2D_MOT_2015/#download). E.g.:\n\n    $ ln -s /path/to/MOT2015_challenge/2DMOT2015 mot_benchmark\n\n')
      exit()

    # 交互模式
    plt.ion()
    fig = plt.figure()
    ax1 = fig.add_subplot(111, aspect='equal')

  if not os.path.exists('output'):
    os.makedirs('output')


  pattern = os.path.join(args.seq_path, phase, '*', 'det', 'det.txt')     # path.join 路径拼接 eg:'.\kuai\train*det\det.txt'


  for seq_dets_fn in glob.glob(pattern):  # glob:获取指定路径下所有满足条件的文件路径名
    # 创建 Sort 对象 mot_tracker
    mot_tracker = Sort(max_age=args.max_age, 
                       min_hits=args.min_hits,
                       iou_threshold=args.iou_threshold) #create instance of the SORT tracker

    seq_dets = np.loadtxt(seq_dets_fn, delimiter=',')   # 按行读取一个文件 : 一行就是一个检测框! [x1,y1,x2,y2,score]

    seq = seq_dets_fn[pattern.find('*'):].split(os.path.sep)[0] # path.sep : 将路径按层切分成列表;获得路径下文件名,在这里就是编号
    
    with open(os.path.join('output', '%s.txt'%(seq)),'w') as out_file:  # 以写方式打开文件
      print("Processing %s."%(seq))
      for frame in range(int(seq_dets[:,0].max())):   # 找到该文件下第一列的最大值。第一列为帧序号,则为找到最大帧号

        frame += 1 #detection and frame numbers begin at 1
        dets = seq_dets[seq_dets[:, 0]==frame, 2:7]   # 取出该 frame 下的 2-7 列
        dets[:, 2:4] += dets[:, 0:2]                  #convert to [x1,y1,w,h] to [x1,y1,x2,y2]  ; 2:4 是 w,h,0:2 是 x1,y1
        total_frames += 1

        if(display):
          # 找到该帧对应的图片,并显示
          fn = os.path.join('mot_benchmark', phase, seq, 'img1', '%06d.jpg'%(frame))
          im =io.imread(fn)
          ax1.imshow(im)
          plt.title(seq + ' Tracked Targets')


        start_time = time.time()
        trackers = mot_tracker.update(dets)   # 用检测框 dets 做一次更新; 返回满足条件可显示的 trackers ([x1,y1,x2,y2,ID])
        cycle_time = time.time() - start_time
        total_time += cycle_time

        # 在画布上画出 trackers 对应的矩形框
        for d in trackers:
          print('%d,%d,%.2f,%.2f,%.2f,%.2f,1,-1,-1,-1'%(frame,d[4],d[0],d[1],d[2]-d[0],d[3]-d[1]),file=out_file)
          if(display):
            d = d.astype(np.int32)
            ax1.add_patch(patches.Rectangle((d[0],d[1]),d[2]-d[0],d[3]-d[1],fill=False,lw=3,ec=colours[d[4]%32,:]))

        if(display):
          fig.canvas.flush_events()   # 更新画图
          plt.draw()
          ax1.cla()

  print("Total Tracking took: %.3f seconds for %d frames or %.1f FPS" % (total_time, total_frames, total_frames / total_time))

  if(display):
    print("Note: to get real runtime results run without the option: --display")

你可能感兴趣的:(MOT,算法,目标跟踪,计算机视觉)