Towards-Realtime-MOT Source Code Study: the update() Function of JDETracker

This function processes the model's raw predictions: it computes the embedding distance and the IOU distance between the objects detected in the current frame and the existing objects (objects tracked in the previous frame, objects lost within a recent window, and objects that newly appeared in the previous frame), finds the best assignment between them, and then handles each object according to its state (still tracked, lost, newly appeared, lost-then-refound, or newly appeared but immediately lost).

self.frame_id += 1
activated_starcks = []     # after this call, activated_starcks holds three kinds of objects:
    # ① tracked in the previous frame and still tracked in this frame,
    # ② newly appeared in the previous frame and matched in this frame,
    # ③ newly appeared in this frame
refind_stracks = []        # STrack objects lost in earlier frames and re-matched in this frame
lost_stracks = []          # STrack objects lost in this frame
removed_stracks = []       # holds: ① new in the previous frame but unmatched in this one,
                           # ② objects lost for more than 30 frames

t1 = time.time()

Here STrack is a class; the original post shows its attributes in a screenshot (figure 1). A minimal sketch of the constructor, reconstructed from the repository's multitracker.py (exact names and defaults are assumptions), is below:
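
from collections import deque
import numpy as np

class STrack(BaseTrack):  # BaseTrack comes from tracker/basetrack.py
    def __init__(self, tlwh, score, temp_feat, buffer_size=30):
        # bounding box as (top-left x, top-left y, width, height); waiting to be activated
        self._tlwh = np.asarray(tlwh, dtype=np.float64)
        self.kalman_filter = None
        self.mean, self.covariance = None, None  # Kalman state, filled in activate()
        self.is_activated = False

        self.score = score            # detection confidence
        self.tracklet_len = 0         # number of frames this track has been matched

        # appearance embedding: an exponentially smoothed feature plus a short history
        self.smooth_feat = None
        self.curr_feat = temp_feat
        self.features = deque([], maxlen=buffer_size)
        self.alpha = 0.9              # EMA coefficient for feature smoothing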

with torch.no_grad():
    pred = self.model(im_blob)  # im_blob is the input image
    # pred is a tensor of all the proposals (54264 proposals by default)
pred = pred[pred[:, :, 4] > self.opt.conf_thres]
# drop proposals whose confidence is below the threshold

After the forward pass, pred.shape = torch.Size([1, 54264, 518]).
pred contains both the bounding boxes and the embeddings; after filtering, pred.shape = torch.Size([2, 518]).

if len(pred) > 0:
    dets = non_max_suppression(pred.unsqueeze(0), self.opt.conf_thres,
           self.opt.nms_thres)[0].cpu()
    # rescale the detections to the original image size
    scale_coords(self.opt.img_size, dets[:, :4], img0.shape).round()
    '''Detections is a list of (x1, y1, x2, y2, object_conf, class_score, class_pred);
    class_pred is the embedding, a 512-dimensional feature vector'''
    detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f.numpy(), 30) for
                          (tlbrs, f) in zip(dets[:, :5], dets[:, 6:])]
else:
    detections = []

After non-maximum suppression, dets.shape = torch.Size([1, 518]).
scale_coords() rescales the boxes back to the original image.
Finally, each row of dets is wrapped into an STrack object and collected in detections, so detections holds the STrack objects of the current frame: detections = [OT_0_(0-0)].

t2 = time.time()

''' Add newly detected tracklets to tracked_stracks'''
unconfirmed = []
tracked_stracks = []  # type: list[STrack]
for track in self.tracked_stracks:  # self.tracked_stracks is the previous frame's tracked set
    if not track.is_activated:
        # tracks that were not yet activated before this frame are new targets and
        # go to the unconfirmed list (e.g. frame 2 enters this branch)
        unconfirmed.append(track)
        # print("Should not be here, in unconfirmed")
    else:  # not a new target detected in the previous frame
        tracked_stracks.append(track)  # established tracks go to tracked_stracks

After this loop, tracked_stracks = [OT_1_(1-3)] and unconfirmed = [].

''' Step 2: match with the Hungarian algorithm '''
# merge tracked_stracks and self.lost_stracks into strack_pool;
# tracked_stracks holds the objects tracked in the previous frame,
# self.lost_stracks the objects lost within the last few (here 30) frames
strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
# the Kalman filter predicts a new mean and covariance for every track in strack_pool
STrack.multi_predict(strack_pool, self.kalman_filter)

Since self.lost_stracks = [], strack_pool = [OT_1_(1-3)] here.

'''Match the objects tracked in the previous frame, plus those lost within the
last (30) frames, against the current frame's STrack objects'''
# distance matrix between the earlier STracks (strack_pool) and the current ones (detections)
dists = matching.embedding_distance(strack_pool, detections)
# fold in the motion model to get a new distance matrix; entries beyond the gating
# threshold are set to np.inf
dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
# Hungarian matching of detection boxes to track boxes, producing the triple
# matches (track/detection pairs that matched), u_track (indices of unmatched
# tracks) and u_detection (indices of unmatched detections)
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)

Calling embedding_distance() computes the appearance-feature distance between every object in strack_pool and every object in detections; the original post shows the function in a screenshot (figure 2). In this run, dists = matching.embedding_distance(strack_pool, detections) = [[0.05161539]]. A minimal sketch of the function, reconstructed from the repository's matching.py (details are assumptions), follows:
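
import numpy as np
from scipy.spatial.distance import cdist

def embedding_distance(tracks, detections, metric='cosine'):
    # cost_matrix[i, j] = appearance distance between track i and detection j
    cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float64)
    if cost_matrix.size == 0:
        return cost_matrix
    det_features = np.asarray([det.curr_feat for det in detections], dtype=np.float64)
    track_features = np.asarray([trk.smooth_feat for trk in tracks], dtype=np.float64)
    # cosine distance, clipped at 0 to guard against numerical noise
    cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric))
    return cost_matrix
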
Calling fuse_motion() folds the Kalman motion model into the cost matrix; the original post shows the function in a screenshot (figure 3). In this run, dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections) = [[0.05067831]]. A sketch under the same caveats:
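
import numpy as np

# 0.95 quantile of the chi-square distribution per degree of freedom; the repo
# keeps this table in its kalman_filter module
chi2inv95 = {1: 3.8415, 2: 5.9915, 3: 7.8147, 4: 9.4877}

def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
    if cost_matrix.size == 0:
        return cost_matrix
    gating_dim = 2 if only_position else 4
    gating_threshold = chi2inv95[gating_dim]
    measurements = np.asarray([det.to_xyah() for det in detections])
    for row, track in enumerate(tracks):
        # Mahalanobis gating distance between the track's predicted state and
        # every detection
        gating_distance = kf.gating_distance(
            track.mean, track.covariance, measurements, only_position)
        # detections outside the gate can never be matched
        cost_matrix[row, gating_distance > gating_threshold] = np.inf
        # blend appearance cost with motion cost
        cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
    return cost_matrix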

linear_assignment() takes the cost matrix and returns the best assignment between the objects in strack_pool and the objects in detections; its key call is:

cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
# cost is the total assignment cost, x[i] is the column matched to row i,
# y[j] is the row matched to column j

The original post shows the full linear_assignment() in a screenshot (figure 4); in the line for ix, mx in enumerate(x):, ix is the row index and mx is the column index matched to that row. A sketch reconstructed from matching.py (details are assumptions):
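
import lap
import numpy as np

def linear_assignment(cost_matrix, thresh):
    if cost_matrix.size == 0:
        # nothing to match: every row and every column is unmatched
        return (np.empty((0, 2), dtype=int),
                tuple(range(cost_matrix.shape[0])),
                tuple(range(cost_matrix.shape[1])))
    matches = []
    cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
    for ix, mx in enumerate(x):  # ix: row index, mx: matched column index (or -1)
        if mx >= 0:
            matches.append([ix, mx])
    unmatched_a = np.where(x < 0)[0]  # rows (tracks) left unmatched
    unmatched_b = np.where(y < 0)[0]  # columns (detections) left unmatched
    return np.asarray(matches), unmatched_a, unmatched_b
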
matches=[[0 0]]
u_track=[], u_detection=[]  

for itracked, idet in matches:
    # itracked indexes strack_pool (tracks from earlier frames),
    # idet indexes detections (the current frame)
    track = strack_pool[itracked]
    det = detections[idet]
    # a track in strack_pool is in one of two states here: Tracked or Lost
    if track.state == TrackState.Tracked:  # matched an object tracked in the previous frame
        # the track is active, so copy the current detection's info into it
        track.update(detections[idet], self.frame_id)
        activated_starcks.append(track)  # collect it in activated_starcks
    else:  # any other state means we matched an object lost in earlier frames
        # the track was inactive, so re-activate it and put it in refind_stracks
        track.re_activate(det, self.frame_id, new_id=False)
        refind_stracks.append(track)

After update() runs, the track carries the matched detection's box, score and feature; the original post shows the resulting variable values in a screenshot (figure 5). A sketch of STrack.update(), reconstructed from the repository (details are assumptions):
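
def update(self, new_track, frame_id, update_feature=True):
    """Absorb the matched detection into this track (method of STrack)."""
    self.frame_id = frame_id
    self.tracklet_len += 1

    # Kalman correction step with the new box, converted to
    # (center x, center y, aspect ratio, height)
    new_tlwh = new_track.tlwh
    self.mean, self.covariance = self.kalman_filter.update(
        self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh))
    self.state = TrackState.Tracked
    self.is_activated = True

    self.score = new_track.score
    if update_feature:
        self.update_features(new_track.curr_feat)  # EMA-smooth the appearance feature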

activated_starcks = [OT_1_(1-4)]

''' Step 3: match the remainder by IOU '''
# u_detection holds the indices of the current frame's unmatched STrack objects
detections = [detections[i] for i in u_detection]
# detections is now the list of unmatched STrack objects in the current frame
r_tracked_stracks = []  # tracks that existed before but were not matched in this frame

for i in u_track:  # iterate over the earlier tracks unmatched in the current frame
    if strack_pool[i].state == TrackState.Tracked:
        r_tracked_stracks.append(strack_pool[i])
dists = matching.iou_distance(r_tracked_stracks, detections)  # IOU distance matrix
# Hungarian matching between detection boxes and track boxes, again producing
# matches, u_track (unmatched track indices) and u_detection (unmatched detections)
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)

Here dists = matching.iou_distance(r_tracked_stracks, detections) = [] and matches = u_track = u_detection = []; they are empty because in my run there were no unmatched STrack objects left. For reference, a minimal sketch of iou_distance(), reconstructed from matching.py (details are assumptions), follows:
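
import numpy as np
from cython_bbox import bbox_overlaps as bbox_ious  # pairwise-IOU kernel the repo uses

def iou_distance(atracks, btracks):
    # accept either raw tlbr arrays or STrack objects
    if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or \
       (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
        atlbrs, btlbrs = atracks, btracks
    else:
        atlbrs = [track.tlbr for track in atracks]
        btlbrs = [track.tlbr for track in btracks]
    if len(atlbrs) == 0 or len(btlbrs) == 0:
        return np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float64)
    ious = bbox_ious(np.ascontiguousarray(atlbrs, dtype=np.float64),
                     np.ascontiguousarray(btlbrs, dtype=np.float64))
    return 1 - ious  # cost = 1 - IOU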

for itracked, idet in matches:  # iterate over the pairs matched by IOU
    track = r_tracked_stracks[itracked]
    det = detections[idet]
    if track.state == TrackState.Tracked:  # an object tracked in the previous frame
        track.update(det, self.frame_id)  # copy the current detection's info into the track
        activated_starcks.append(track)   # collect it in activated_starcks
    else:  # we matched an object lost in earlier frames
        # refresh its frame_id etc.; once activated, is_activated stays True even when lost
        track.re_activate(det, self.frame_id, new_id=False)
        refind_stracks.append(track)  # re-activated tracks go to refind_stracks

In other words, if IOU matching succeeds, we walk the matches matrix and update the STrack objects' information.

for it in u_track:  # u_track now holds tracks unmatched by both embedding and IOU matching
    track = r_tracked_stracks[it]
    if not track.state == TrackState.Lost:
        # objects lost several frames ago are already in the Lost state; an object
        # present in the previous frame but lost in this one is not yet
        track.mark_lost()           # flip the unmatched track's state to Lost
        lost_stracks.append(track)  # and collect it in lost_stracks

At this point u_track holds the STrack objects that neither the embedding (Hungarian) step nor the IOU step matched, i.e. tracks that existed before but are absent from the current frame; each is marked lost and pushed into lost_stracks. mark_lost() is a one-liner (a sketch, reconstructed from basetrack.py):
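
def mark_lost(self):
    self.state = TrackState.Lost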
 

detections = [detections[i] for i in u_detection]  # the still-unmatched current detections
# unconfirmed holds the STrack objects newly detected in the previous frame;
# e.g. on frame 2 it receives the STrack objects created on frame 1
dists = matching.iou_distance(unconfirmed, detections)  # IOU distance matrix
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)

unconfirmed holds the STrack objects newly detected in the previous frame; we compute the IOU distance between them and the current frame's still-unmatched STrack objects.
Hungarian matching then yields matches (index pairs of matched STrack objects), u_unconfirmed (objects new in the previous frame but unmatched in this one), and u_detection (current-frame objects that remain unmatched).

for itracked, idet in matches:
    unconfirmed[itracked].update(detections[idet], self.frame_id)
    # update() sets self.state = TrackState.Tracked and is_activated = True
    activated_starcks.append(unconfirmed[itracked])  # collect it in activated_starcks

This step sets is_activated = True on STrack objects that newly appeared in the previous frame and were matched in this one.

for it in u_unconfirmed:  # new in the previous frame, but unmatched in this one
    track = unconfirmed[it]
    track.mark_removed()  # flip the track's state to TrackState.Removed
    removed_stracks.append(track)  # and collect it in removed_stracks

mark_removed() is the symmetric one-liner (again a sketch from basetrack.py):
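
def mark_removed(self):
    self.state = TrackState.Removed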

""" Step 4: Init new stracks"""
for inew in u_detection:#之前未有的新目标
    track = detections[inew]
    if track.score < self.det_thresh:# 如果该track的得分低于阈值,则遍历下一个
        continue
    track.activate(self.kalman_filter, self.frame_id)# 如果得分≥阈值,激活该track
    # 这里激活后的track.is_activated不是True
    activated_starcks.append(track) #将该track添加到activated_starcks中

u_detection holds the current frame's unmatched STrack objects; this step assigns ids to newly appeared targets. self.det_thresh = 0.5. A sketch of STrack.activate(), reconstructed from the repository (details are assumptions), shows why is_activated stays False at first:
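
def activate(self, kalman_filter, frame_id):
    """Start a new tracklet (method of STrack)."""
    self.kalman_filter = kalman_filter
    self.track_id = self.next_id()  # hand out a fresh id
    # initialise the Kalman state from the detection box
    self.mean, self.covariance = self.kalman_filter.initiate(
        self.tlwh_to_xyah(self._tlwh))

    self.tracklet_len = 0
    self.state = TrackState.Tracked
    # is_activated is deliberately left False here; it only becomes True once the
    # track is matched again later (via update() or re_activate())
    self.frame_id = frame_id
    self.start_frame = frame_id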

""" Step 5: Update state"""
# 如果跟丢的目标超过间隔帧阈值, 就将其移除.
for track in self.lost_stracks: #遍历跟丢的STrack对象
    if self.frame_id - track.end_frame > self.max_time_lost:
        track.mark_removed() # 修改track的状态为TrackState.Removed
        removed_stracks.append(track) # 将该track添加到removed_stracks中

self.max_time_lost = 30, i.e. once a lost STrack's last frame (end_frame) trails the current frame by more than 30, its state is flipped to Removed and it is appended to removed_stracks.

# Update the self.tracked_stracks and self.lost_stracks using the updates
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == 
                                TrackState.Tracked]
# on the first frame, only activated_starcks is non-empty
self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
# self.lost_stracks = [t for t in self.lost_stracks if t.state == TrackState.Lost]
# type: list[STrack]
self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
self.lost_stracks.extend(lost_stracks)
self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
self.removed_stracks.extend(removed_stracks)
self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
            self.tracked_stracks, self.lost_stracks)

self.tracked_stracks now contains three groups of STrack objects:
        ① objects tracked in the previous frame that are still in the Tracked state (self.tracked_stracks if t.state == TrackState.Tracked)
        ② activated_starcks, which itself contains three kinds of objects: tracked in the previous frame and still tracked now, newly appeared in the previous frame and matched now, and newly appeared in this frame
        ③ objects lost in earlier frames but re-matched in this one (refind_stracks)
self.tracked_stracks = OT_1_(1-4)
self.lost_stracks = []
self.removed_stracks = []
Finally, duplicate objects are removed from self.tracked_stracks and self.lost_stracks. The merging and subtraction above use the helpers joint_stracks() and sub_stracks(); minimal sketches, reconstructed from multitracker.py (details are assumptions), follow:
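
def joint_stracks(tlista, tlistb):
    # union of two track lists, keyed by track_id (first occurrence wins)
    exists = {}
    res = []
    for t in tlista:
        exists[t.track_id] = 1
        res.append(t)
    for t in tlistb:
        if not exists.get(t.track_id, 0):
            exists[t.track_id] = 1
            res.append(t)
    return res

def sub_stracks(tlista, tlistb):
    # every track in tlista whose track_id does not appear in tlistb
    stracks = {t.track_id: t for t in tlista}
    for t in tlistb:
        stracks.pop(t.track_id, None)
    return list(stracks.values())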

output_stracks = [track for track in self.tracked_stracks if track.is_activated]

logger.debug('===========Frame {}=========='.format(self.frame_id))
logger.debug('Activated: {}'.format([track.track_id for track in activated_starcks]))
logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
# print('Final {} s'.format(t5-t4))
return output_stracks

The returned output_stracks = [OT_1_(1-4)] contains all the STrack objects being tracked in the current frame.
