MOT学习 - SORT算法

paper:https://arxiv.org/abs/1602.00763
code:https://github.com/open-mmlab/mmtracking
https://github.com/abewley/sort

摘要
SORT:Simple Online And Realtime Tracking,方法介绍

  • online(只利用历史帧信息)+ realtime(很快,260 Hz)
  • 方法是对卡尔曼滤波+匈牙利匹配进行组合

此外,提到了检测对于MOT的重要性,简单换一个检测器,MOT的性能可以提高18.9%左右。

ID创建&删除

  1. 当obj进入或者离开图像, unique id会被创建或者删除
  2. 创建: 任何与已有目标的重叠(IoU)小于 IOUmin 的检测, 都被视为一个尚未被跟踪的新对象。这时tracker初始化为: 状态取检测框, 对应的速度设为0.
    PS:由于此时未观察到速度,因此将速度分量的协方差初始化为较大的值,反映了这种不确定性。
    此外, 创建的new tracker会经历一段试用期, 在这个期间, obj需要和检测结果关联足够多的帧(3帧), 以防止跟踪误报。
  3. 删除: 如果一个Track连续TLost帧没有关联到任何检测, 就会被删除。
    这样做的目的: 一是防止trackers的数量无限增长; 二是防止长时间只靠预测、没有检测结果校正而导致的定位误差(检测器没检测到时, 只能使用tracker的预测值)。
    在所有实验中, TLost设置为1。这样设置的原因: 首先, 恒速模型不能很好地预测真实的运动; 其次, 我们主要关注帧到帧的跟踪, 对象重识别(re-identification)超出了这项工作的范围。此外, 尽早删除丢失的目标有助于提高效率。如果一个对象再次出现, 跟踪会隐式地以一个新的ID恢复。

调整的参数
卡尔曼滤波的协方差, IOUmin, TLost。

代码阅读
https://github.com/abewley/sort/blob/bce9f0d1fc8fb5f45bf7084130248561a3d42f31/sort.py

import os
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from skimage import io

import glob
import time
import argparse
from filterpy.kalman import KalmanFilter

# Seed NumPy's global RNG so values drawn from np.random in this script are reproducible across runs.
np.random.seed(0)


def linear_assignment(cost_matrix):
  """
  Solve the linear assignment problem for ``cost_matrix``.

  ``lap.lapjv`` is used when the optional ``lap`` package can be imported;
  otherwise ``scipy.optimize.linear_sum_assignment`` is used instead.

  Returns an array of [row, column] index pairs, one pair per assignment.
  """
  try:
    import lap
    _, col_of_row, row_of_col = lap.lapjv(cost_matrix, extend_cost=True)
    # Negative entries are skipped (presumably lap's marker for an unassigned row).
    pairs = [[row_of_col[col], col] for col in col_of_row if col >= 0]
    return np.array(pairs)
  except ImportError:
    from scipy.optimize import linear_sum_assignment
    row_idx, col_idx = linear_sum_assignment(cost_matrix)
    return np.array([[r, c] for r, c in zip(row_idx, col_idx)])


def iou_batch(bb_test, bb_gt):
  """
  Pairwise intersection-over-union between two sets of boxes given as
  [x1,y1,x2,y2] (only the first four columns of each input are read).

  bb_test: m rows of boxes
  bb_gt:   n rows of boxes
  return:  m x n matrix whose entry (i, j) is the IOU of bb_test[i] and bb_gt[j]
  """
  # Insert singleton axes so every test box broadcasts against every gt box.
  gt = np.expand_dims(bb_gt, 0)
  test = np.expand_dims(bb_test, 1)

  # Corners of the intersection rectangle.
  left = np.maximum(test[..., 0], gt[..., 0])
  top = np.maximum(test[..., 1], gt[..., 1])
  right = np.minimum(test[..., 2], gt[..., 2])
  bottom = np.minimum(test[..., 3], gt[..., 3])

  # Clamp at zero so disjoint boxes contribute no overlap.
  inter = np.maximum(0., right - left) * np.maximum(0., bottom - top)

  area_test = (test[..., 2] - test[..., 0]) * (test[..., 3] - test[..., 1])
  area_gt = (gt[..., 2] - gt[..., 0]) * (gt[..., 3] - gt[..., 1])
  return inter / (area_test + area_gt - inter)


def convert_bbox_to_z(bbox):
  """
  Convert a corner-form box [x1,y1,x2,y2] (extra trailing entries are ignored)
    into the column vector [x,y,s,r]: x,y is the box centre, s is the
    scale (area) and r is the aspect ratio width/height.
  """
  left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
  width = right - left
  height = bottom - top
  centre_x = left + width/2.
  centre_y = top + height/2.
  area = width * height    #scale is just area
  ratio = width / float(height)
  return np.array([centre_x, centre_y, area, ratio]).reshape((4, 1))


def convert_x_to_bbox(x,score=None):
  """
  Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
    [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right.

  x may be a flat sequence or a column vector (e.g. a 7x1 Kalman state); only
    its first four entries are read.
  When score is given, it is appended as a fifth column and the result has
    shape (1,5); otherwise the result has shape (1,4).
  """
  # Flatten first so the four components are scalars whether x is (n,) or (n,1).
  state = np.asarray(x).reshape(-1)
  cx, cy, s, r = state[0], state[1], state[2], state[3]
  w = np.sqrt(s * r)  # s = w*h and r = w/h, hence s*r = w**2
  h = s / w
  box = [cx - w/2., cy - h/2., cx + w/2., cy + h/2.]
  # Identity test: `== None` would be evaluated element-wise for an array score.
  if score is None:
    return np.array(box).reshape((1, 4))
  # np.append flattens both arguments, so a scalar or a length-1 array score works.
  return np.append(box, score).reshape((1, 5))


class KalmanBoxTracker(object):
  """
  Per-object tracker: holds the Kalman filter state and bookkeeping counters
  for one tracked object observed as a bounding box.
  """
  count = 0  # trackers created so far; each new tracker takes the current value as its id

  def __init__(self,bbox):
    """
    Create a tracker whose initial state is taken from the bounding box ``bbox``.
    """
    # Constant velocity model.
    # State is (u, v, s, r, du, dv, ds); the measurement is (u, v, s, r).
    kf = KalmanFilter(dim_x=7, dim_z=4)
    kf.F = np.array([
      [1,0,0,0,1,0,0],
      [0,1,0,0,0,1,0],
      [0,0,1,0,0,0,1],
      [0,0,0,1,0,0,0],
      [0,0,0,0,1,0,0],
      [0,0,0,0,0,1,0],
      [0,0,0,0,0,0,1],
    ])
    kf.H = np.array([
      [1,0,0,0,0,0,0],
      [0,1,0,0,0,0,0],
      [0,0,1,0,0,0,0],
      [0,0,0,1,0,0,0],
    ])

    kf.R[2:,2:] *= 10.  # measurement noise of s and r
    kf.P[4:,4:] *= 1000. # give high uncertainty to the unobservable initial velocities
    kf.P *= 10.
    kf.Q[-1,-1] *= 0.01
    kf.Q[4:,4:] *= 0.01

    kf.x[:4] = convert_bbox_to_z(bbox)  # only the first four state entries are set from the box
    self.kf = kf

    # Number of predict() calls since the last update(); update() resets it to 0,
    # so a value of 0 means the tracker was matched to a detection after its latest predict.
    self.time_since_update = 0
    self.id = KalmanBoxTracker.count  # 0-based id
    KalmanBoxTracker.count += 1  # ids keep increasing across all trackers ever created
    self.history = []
    self.hits = 0  # total number of update() calls
    self.hit_streak = 0  # number of consecutive updates; reset by predict() after a miss
    self.age = 0  # total number of predict() calls

  def update(self,bbox):
    """
    Correct the state vector with the observed bounding box ``bbox``.
    """
    self.time_since_update = 0
    self.history = []
    self.hits += 1
    self.hit_streak += 1
    measurement = convert_bbox_to_z(bbox)
    self.kf.update(measurement)

  def predict(self):
    """
    Advance the state vector one step and return the predicted bounding box.
    """
    kf = self.kf
    # If scale + scale velocity would be non-positive, cancel the scale velocity
    # before predicting.
    if kf.x[6] + kf.x[2] <= 0:
      kf.x[6] *= 0.0
    kf.predict()
    self.age += 1
    # A predict without an intervening update breaks the streak. This check runs
    # before the counter is incremented, so the reset happens on the predict
    # that follows a missed update.
    if self.time_since_update > 0:
      self.hit_streak = 0
    self.time_since_update += 1
    predicted_box = convert_x_to_bbox(kf.x)
    self.history.append(predicted_box)
    return predicted_box

  def get_state(self):
    """
    Return the current bounding box estimate.
    """
    return convert_x_to_bbox(self.kf.x)


def associate_detections_to_trackers(detections,trackers,iou_threshold = 0.3):
  """
  Assigns detections to tracked object (both represented as bounding boxes)

  detections: detector output, one row per detection (columns x1,y1,x2,y2,...)
  trackers:   boxes predicted by the trackers, one row per tracker

  Returns 3 arrays: matches (k x 2, column 0 = detection index, column 1 =
  tracker index), unmatched_detections and unmatched_trackers.

  NOTE(review): the end of this function and the head of the Sort class were
  missing from the listing this file was recovered from; they were rewritten
  to follow the reference SORT implementation - confirm against upstream.
  """
  if(len(trackers)==0):
    return np.empty((0,2),dtype=int), np.arange(len(detections)), np.empty((0,5),dtype=int)

  iou_matrix = iou_batch(detections, trackers)  # detections x trackers

  if min(iou_matrix.shape) > 0:
    a = (iou_matrix > iou_threshold).astype(np.int32)  # 0/1 matrix of candidate pairs
    if a.sum(1).max() == 1 and a.sum(0).max() == 1:
      # Every row and column has at most one candidate: the pairing is already one-to-one.
      matched_indices = np.stack(np.where(a), axis=1)
    else:
      # Negate so that minimising cost maximises total IOU.
      matched_indices = linear_assignment(-iou_matrix)
  else:
    matched_indices = np.empty(shape=(0,2))

  unmatched_detections = [d for d in range(len(detections)) if d not in matched_indices[:,0]]
  unmatched_trackers = [t for t in range(len(trackers)) if t not in matched_indices[:,1]]

  #filter out matched with low IOU
  # An assigned pair whose IOU is below the threshold is treated as unmatched on both sides.
  matches = []
  for m in matched_indices:
    if(iou_matrix[m[0], m[1]]<iou_threshold):
      unmatched_detections.append(m[0])
      unmatched_trackers.append(m[1])
    else:
      matches.append(m.reshape(1,2))
  if(len(matches)==0):
    matches = np.empty((0,2),dtype=int)
  else:
    matches = np.concatenate(matches,axis=0)

  return matches, np.array(unmatched_detections), np.array(unmatched_trackers)


class Sort(object):
  """
  Multi-object tracker: keeps a list of KalmanBoxTracker instances and
  associates them with the detections of each new frame.
  """
  def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3):
    """
    Sets key parameters for SORT

    max_age:       a tracker is dropped once its time_since_update exceeds this value
    min_hits:      hit streak a tracker needs before its box is reported
    iou_threshold: minimum IOU for a detection/tracker pair to count as a match
    """
    self.max_age = max_age
    self.min_hits = min_hits
    self.iou_threshold = iou_threshold
    self.trackers = []
    self.frame_count = 0

  def update(self, dets=np.empty((0, 5))):
    """
    Params:
      dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
    Requires: this method must be called once for each frame even with empty detections (use np.empty((0, 5)) for frames without detections).
    Returns a similar array, where the last column is the object ID.

    NOTE: The number of objects returned may differ from the number of detections provided.
    """
    self.frame_count += 1
    # get predicted locations from existing trackers.
    trks = np.zeros((len(self.trackers), 5))
    to_del = []
    ret = []
    for t, trk in enumerate(trks):
      pos = self.trackers[t].predict()[0]
      trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
      if np.any(np.isnan(pos)):
        to_del.append(t)
    # Drop rows holding invalid predictions, and the trackers that produced them.
    trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
    for t in reversed(to_del):
      self.trackers.pop(t)
    matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets, trks, self.iou_threshold)

    # update matched trackers with assigned detections
    for m in matched:
      self.trackers[m[1]].update(dets[m[0], :])

    # create and initialise new trackers for unmatched detections
    for i in unmatched_dets:
      trk = KalmanBoxTracker(dets[i,:])
      self.trackers.append(trk)
    i = len(self.trackers)
    for trk in reversed(self.trackers):
      d = trk.get_state()[0]
      # Report a tracker only if it was updated in this frame and either has
      # hit_streak >= min_hits or the sequence is still within its first min_hits frames.
      if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
        ret.append(np.concatenate((d, [trk.id+1])).reshape(1,-1)) # +1 as MOT benchmark requires positive
      i -= 1
      # remove dead tracklet: it has gone more than max_age predictions without an update
      if(trk.time_since_update > self.max_age):
        self.trackers.pop(i)
    if(len(ret)>0):
      return np.concatenate(ret)
    return np.empty((0,5))

def parse_args():
    """Build the SORT demo command-line parser and parse the process arguments."""
    cli = argparse.ArgumentParser(description='SORT demo')
    # Each entry is (flag, keyword arguments for add_argument), in display order.
    options = [
        ('--display', dict(dest='display',
                           help='Display online tracker output (slow) [False]',
                           action='store_true')),
        ('--seq_path', dict(help="Path to detections.", type=str, default='data')),
        ('--phase', dict(help="Subdirectory in seq_path.", type=str, default='train')),
        ('--max_age', dict(
            help="Maximum number of frames to keep alive a track without associated detections.",
            type=int, default=1)),
        ('--min_hits', dict(
            help="Minimum number of associated detections before track is initialised.",
            type=int, default=3)),
        ('--iou_threshold', dict(help="Minimum IOU for match.", type=float, default=0.3)),
    ]
    for flag, spec in options:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

if __name__ == '__main__':
  # Run the tracker over every sequence matching <seq_path>/<phase>/*/det/det.txt,
  # write one result file per sequence into ./output and report the overall speed.
  args = parse_args()
  display = args.display
  phase = args.phase
  total_time = 0.0
  total_frames = 0
  colours = np.random.rand(32, 3) #used only for display
  if(display):
    if not os.path.exists('mot_benchmark'):
      print('\n\tERROR: mot_benchmark link not found!\n\n    Create a symbolic link to the MOT benchmark\n    (https://motchallenge.net/data/2D_MOT_2015/#download). E.g.:\n\n    $ ln -s /path/to/MOT2015_challenge/2DMOT2015 mot_benchmark\n\n')
      exit()
    plt.ion()
    fig = plt.figure()
    ax1 = fig.add_subplot(111, aspect='equal')

  os.makedirs('output', exist_ok=True)
  pattern = os.path.join(args.seq_path, phase, '*', 'det', 'det.txt')
  # e.g. on Windows: pattern == 'data\\train\\*\\det\\det.txt'
  for seq_dets_fn in glob.glob(pattern):
    # e.g. seq_dets_fn == 'data\\train\\ADL-Rundle-6\\det\\det.txt'
    mot_tracker = Sort(max_age=args.max_age,
                       min_hits=args.min_hits,
                       iou_threshold=args.iou_threshold) #create instance of the SORT tracker
    # ndmin=2 keeps a file holding a single detection two-dimensional.
    seq_dets = np.loadtxt(seq_dets_fn, delimiter=',', ndmin=2)
    # The sequence name is the path component that the '*' in the pattern matched.
    seq = seq_dets_fn[pattern.find('*'):].split(os.path.sep)[0]

    with open(os.path.join('output', '%s.txt'%(seq)),'w') as out_file:
      print("Processing %s."%(seq))
      #detection and frame numbers begin at 1
      for frame in range(1, int(seq_dets[:,0].max()) + 1):
        dets = seq_dets[seq_dets[:, 0]==frame, 2:7]
        dets[:, 2:4] += dets[:, 0:2] #convert [x1,y1,w,h] to [x1,y1,x2,y2]
        total_frames += 1

        if(display):
          fn = os.path.join('mot_benchmark', phase, seq, 'img1', '%06d.jpg'%(frame))
          im =io.imread(fn)
          ax1.imshow(im)
          plt.title(seq + ' Tracked Targets')

        # Only the tracker update is timed; no debugger breakpoints may sit in
        # this region or the reported FPS becomes meaningless.
        start_time = time.time()
        trackers = mot_tracker.update(dets)
        cycle_time = time.time() - start_time
        total_time += cycle_time

        for d in trackers:
          print('%d,%d,%.2f,%.2f,%.2f,%.2f,1,-1,-1,-1'%(frame,d[4],d[0],d[1],d[2]-d[0],d[3]-d[1]),file=out_file)
          if(display):
            d = d.astype(np.int32)
            ax1.add_patch(patches.Rectangle((d[0],d[1]),d[2]-d[0],d[3]-d[1],fill=False,lw=3,ec=colours[d[4]%32,:]))

        if(display):
          fig.canvas.flush_events()
          plt.draw()
          ax1.cla()

  # Avoid a ZeroDivisionError when no sequence was found or nothing took measurable time.
  fps = total_frames / total_time if total_time > 0 else 0.0
  print("Total Tracking took: %.3f seconds for %d frames or %.1f FPS" % (total_time, total_frames, fps))

  if(display):
    print("Note: to get real runtime results run without the option: --display")

你可能感兴趣的:(多目标追踪,学习,算法,目标跟踪,SORT)