该函数的主要作用是处理模型的预测输出:先计算当前帧检测到的目标对象与已有目标对象(前一帧跟踪到的、一定时间内跟丢的、前一帧新出现的目标)之间的特征距离和IoU距离,得到当前帧目标对象与已有目标对象的最佳匹配;然后根据目标对象的状态(跟住、跟丢、新出现、跟丢又出现、新出现但跟丢)进行相应的处理。
self.frame_id += 1
activated_starcks = [] # activated_starcks在这段代码结束后包含三部分的对象:
# ①前一帧跟住且这一帧也跟住了的,②前一帧新出现这一帧跟住了,③这一帧新出现的对象
refind_stracks = [] # 存放前几帧跟丢之后,这一帧重新匹配上了的STrack对象
lost_stracks = [] # 用于存放这一帧跟丢的STrack对象
removed_stracks = [] # 存放:①上一帧新出现这一帧未匹配到的,②超过30帧跟丢的对象
t1 = time.time()
其中STrack是一个类,STrack类所包含属性的代码如下:
with torch.no_grad():
pred = self.model(im_blob) #im_blob为输入图像
# pred is tensor of all the proposals (proposals默认大小为: 54264)
pred = pred[pred[:, :, 4] > self.opt.conf_thres]
# 丢弃置信度低于阈值的proposals
pred经过model预测后,pred.shape=torch.Size([1, 54264, 518])
pred里面包括bounding box 和embeddings,经过筛选后pred.shape=torch.Size([2, 518])
if len(pred) > 0:
dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres,
self.opt.nms_thres)[0].cpu()
# 改变detection大小
scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
'''Detections is list of (x1, y1, x2, y2, object_conf, class_score, class_pred)
class_pred就是embeddings.为512维的特征向量'''
detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
(tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
else:
detections = []
经过非极大值抑制之后,dets.shape=torch.Size([1, 518])
scale_coords的作用是将检测框坐标(dets的前4列)从网络输入尺寸self.opt.img_size缩放回原始图像img0的尺寸
最后遍历dets,将其封装为一个个的STrack对象,再放入detections中,即detections中存放的是当前图像帧中的STrack对象,detections=[OT_0_(0-0)]
t2 = time.time()
''' Add newly detected tracklets to tracked_stracks'''
unconfirmed = []
tracked_stracks = [] # type: list[STrack]
for track in self.tracked_stracks:#self.tracked_stracks为上一帧的tracked_stracks集合
if not track.is_activated:
# 在处理当前帧时,以前没有激活的tracks会被添加到unconfirmed list中,也就是新目标
unconfirmed.append(track) # 如第2帧会进入这里
# print("Should not be here, in unconfirmed")
else: # 不是上一帧检测到的新目标
tracked_stracks.append(track)# 不是上一帧检测到的新目标track放到tracked_stracks
执行完后tracked_stracks=[OT_1_(1-3)],unconfirmed=[]
''' Step 2: 使用匈牙利算法进行匹配'''
# 将[activated_strack,lost_stracks]融合成pool_stracks,
# tracked_stracks为上一帧跟住的,self.lost_stracks为上几(这里设定的是30)帧跟丢的对象
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
# 卡尔曼算法预测pool_stracks的新的mean,covariance。
STrack.multi_predict(strack_pool, self.kalman_filter)
由于self.lost_stracks=[],此时strack_pool = [OT_1_(1-3)]
'''将上一帧跟住的,和前几(30)帧内跟丢的对象与当前帧的STrack对象计算距离进行匹配'''
# 计算之前的STrack(pool_stracks)和当前帧的STrack(detection)的距离矩阵
dists = matching.embedding_distance(strack_pool, detections)
# 加入运动模型,得到新的距离矩阵,并将大于阈值的矩阵赋值为np.inf
dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
# 利用匈牙利算法将检测框和跟踪框进行匹配,得到matches(matches是能匹配的track和detection),
# u_track(为未匹配的tracker id), u_detection三元组。
# 这里是将上一帧跟住的,和前几(30)帧内跟丢的对象与当前帧的STrack对象进行匹配。
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)
调用embedding_distance()函数后,该函数会计算出strack_pool中每个对象与detections中每个对象之间的特征距离(embedding_distance()函数如下)。本例中的计算结果为:dists = matching.embedding_distance(strack_pool, detections)=[[0.05161539]]
调用fuse_motion()函数后,会在特征距离矩阵中加入运动模型的信息,得到新的距离矩阵(fuse_motion()函数如下)。本例中的计算结果为:dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)=[[0.05067831]]
调用linear_assignment()函数,该函数根据特征距离矩阵,得到每个strack_pool中的对象与每个detections中的对象的最佳匹配,其中的关键函数:
cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
# cost为最小开销,x为每行所匹配到的列索引,y为每列所匹配到的行索引
linear_assignment()完整函数如下,其中for ix, mx in enumerate(x):这一行代码,ix就是行的索引,mx为与该ix匹配的列索引
matches=[[0 0]]
u_track=[], u_detection=[]
for itracked, idet in matches:
# itracked是以前帧中STrack对象的id,idet是当前帧Strack对象的id
track = strack_pool[itracked]
det = detections[idet]
# track只有3种状态:Tracked、Lost、Removed
if track.state == TrackState.Tracked:# 匹配到了前一帧跟住的STrack对象
# If the track is active, 更新当前帧Strack对象信息给track
track.update(detections[idet], self.frame_id)
activated_starcks.append(track) # 将track添加到activated_starcks中
else: # 状态不是TrackState.Tracked就说明匹配到了前几帧跟丢的目标
# 如果该track没有激活, put the track in refind_stracks list
track.re_activate(det, self.frame_id, new_id=False)#就重新激活
refind_stracks.append(track)
activated_starcks=[OT_1_(1-4)]
''' Step 3: 剩下的,用IOU进行匹配'''
# u_detection是未匹配到的当前帧中的STrack对象
detections = [detections[i] for i in u_detection]
# detections是当前帧中未匹配到的STrack对象的list集合
r_tracked_stracks = [] # 这个容器存放之前存在、但当前帧中不存在的STrack对象
for i in u_track: # 遍历之前存在、但当前帧中未匹配到的STrack对象
if strack_pool[i].state == TrackState.Tracked:
r_tracked_stracks.append(strack_pool[i])
dists = matching.iou_distance(r_tracked_stracks, detections)#计算IOU距离
# 利用匈牙利算法将检测框和跟踪框进行匹配,得到matches(matches是能匹配的track和detection),
# u_track(为未匹配的tracker id), u_detection三元组。
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)
dists = matching.iou_distance(r_tracked_stracks, detections)=[]
matches=u_track=u_detection=[],三者均为空,是由于本例中不存在未匹配到的STrack对象
for itracked, idet in matches: # 遍历使用iou匹配方式匹配上的
track = r_tracked_stracks[itracked]
det = detections[idet]
if track.state == TrackState.Tracked: # 上一帧跟住了的STrack对象
track.update(det, self.frame_id) # 更新当前帧Strack对象信息给track
activated_starcks.append(track) # 将track添加到activated_starcks中
else: # 说明匹配到了前几帧跟丢的对象
# 就给该目标更新fram_id等信息,对象一旦被激活,即便跟丢is_activated也是True
track.re_activate(det, self.frame_id, new_id=False)
refind_stracks.append(track) # 重新激活的track要放到refind_stracks中
这一步是说,如果使用iou匹配能匹配到的话,就遍历matches匹配矩阵,更新STrack对象的信息
for it in u_track: # 此时的u_track为匈牙利算法和iou匹配均未匹配到的STrack对象
track = r_tracked_stracks[it]
if not track.state == TrackState.Lost:
# 前几帧均跟丢的状态肯定为Lost,上一帧有但这一帧跟丢的目标不是Lost
track.mark_lost() # 将未匹配成功的track的状态改为lost
lost_stracks.append(track) # 将该track放到lost_stracks中
此时的u_track为特征距离匹配和IoU匹配(两次均使用linear_assignment求解)都未匹配到的STrack对象(之前存在、但当前帧中不存在的STrack对象)。这一步将未匹配成功的track的状态改为丢失(Lost),并将该track放入lost_stracks集合中。mark_lost()函数中的代码为:
detections = [detections[i] for i in u_detection]# 找出当前帧中未匹配到的STrack对象
# unconfirmed存放的是上一帧新检测到的STrack对象
# 如第2帧时,会将第1帧的STrack对象放到unconfirmed
dists = matching.iou_distance(unconfirmed, detections) #计算iou距离
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
unconfirmed存放的是上一帧新检测到的STrack对象,计算当前帧中未匹配成功的STrack对象与上一帧新检测到的STrack对象的iou距离
然后调用匈牙利匹配算法,matches中存放匹配到的STrack对象的索引,u_unconfirmed存放的是上一帧新检测到的且当前帧中未匹配到的STrack对象,u_detection为当前帧中未匹配到的STrack对象
for itracked, idet in matches:
unconfirmed[itracked].update(detections[idet], self.frame_id)
# 修改了状态self.state = TrackState.Tracked,is_activated = True
activated_starcks.append(unconfirmed[itracked]) # 将其添加到activated_starcks中
这一步会设置上一帧新出现的且这一帧匹配上了的STrack对象的is_activated = True
for it in u_unconfirmed: # 遍历上一帧新检测到的且当前帧中未匹配到的
track = unconfirmed[it]
track.mark_removed() # 修改track的状态为TrackState.Removed
removed_stracks.append(track) #将该track添加到removed_stracks中
""" Step 4: Init new stracks"""
for inew in u_detection:#之前未有的新目标
track = detections[inew]
if track.score < self.det_thresh:# 如果该track的得分低于阈值,则遍历下一个
continue
track.activate(self.kalman_filter, self.frame_id)# 如果得分≥阈值,激活该track
# 这里激活后的track.is_activated不是True
activated_starcks.append(track) #将该track添加到activated_starcks中
u_detection中存放的是当前帧中未匹配到的STrack对象,这一步是对新出现的目标分配id
self.det_thresh=0.5
""" Step 5: Update state"""
# 如果跟丢的目标超过间隔帧阈值, 就将其移除.
for track in self.lost_stracks: #遍历跟丢的STrack对象
if self.frame_id - track.end_frame > self.max_time_lost:
track.mark_removed() # 修改track的状态为TrackState.Removed
removed_stracks.append(track) # 将该track添加到removed_stracks中
self.max_time_lost=30,即当当前帧索引frame_id与丢失的STrack对象的end_frame之差超过30时,就将其状态修改为Removed并添加到removed_stracks列表中
# Update the self.tracked_stracks and self.lost_stracks using the updates
self.tracked_stracks = [t for t in self.tracked_stracks if t.state ==
TrackState.Tracked]
#第一帧时只有activated_starcks
self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
# self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost]
# type: list[STrack]
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
self.lost_stracks.extend(lost_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
self.removed_stracks.extend(removed_stracks)
self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
self.tracked_stracks, self.lost_stracks)
self.tracked_stracks包含三个部分的STrack对象:
①上一帧跟住且这一帧也为跟住的对象(self.tracked_stracks if t.state==TrackState.Tracked)
②activated_starcks,其中又包含三部分对象:前一帧跟住且这一帧也跟住了的、前一帧新出现且这一帧跟住了的、这一帧新出现的STrack对象
③前几帧跟丢但这一帧跟上了的STrack对象(refind_stracks)
self.tracked_stracks=[OT_1_(1-4)]
self.lost_stracks =[]
self.removed_stracks=[]
然后再移除掉self.tracked_stracks、self.lost_stracks列表中重复的对象
output_stracks = [track for track in self.tracked_stracks if track.is_activated]
logger.debug('===========Frame {}=========='.format(self.frame_id))
logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
# print('Final {} s'.format(t5-t4))
return output_stracks
最终返回的output_stracks=[OT_1_(1-4)],其中存放的是self.tracked_stracks中is_activated为True的对象,即当前帧中所有跟住的STrack对象