目标跟踪:yolov4目标检测 + sort目标跟踪

本文主要是讲yolov4目标检测 + 目标跟踪sort整体的处理过程:

github地址: https://github.com/Deepcong2019/yolov4_sort

1、基于yolov4进行目标检测;

2、基于sort对目标检测的结果进行跟踪:

2.1 在获取到跟踪类别预测框的第一帧,通过匈牙利匹配结果产生unmatched_dets,给每个目标框分配一个轨迹和id;详细过程:https://blog.csdn.net/DeepCBW/article/details/124740092

2.2 接下来就是逐帧怎样进行预测和更新;

1、怎样predict ?
2、怎样update ?
3、trk.hit_streak怎样实现连击 >= min_hits时,赋予该trk一个id ?
4、trk.time_since_update > max_age时,删除该轨迹trk的实现方式?

2.2 中的4个小问题精髓答案:
# update matched trackers with assigned detections
# 只有匹配上才会update,此时,每个track的time_since_update重新归置为0,
# 没有匹配上轨迹的detections,将会赋予新的轨迹。
# 新轨迹连续匹配上update时,hit_streak += 1,大于min_hits时才会赋予新的id。
# 新轨迹未如果没匹配上detection,这个track的time_since_update+=1,重新将hit_streak = 0,直到这个trk.time_since_update>max_age, 删除该轨迹。
卡尔曼的详细设计过程:https://blog.csdn.net/DeepCBW/article/details/124759547
先贴上yolov4目标检测 + 目标跟踪sort整体过程的处理代码:

# -*- coding: utf-8 -*-
"""
Time    : 2022/5/12 08:35
Author  : cong
"""
from sort_class import *
from PIL import Image
from yolo_slim import *
import numpy as np
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
import re
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
yolo = YOLO()

# ------------------------------------------------------------------------
video_path = "video/310_r11.mp4"
video_save_path = re.split('.mp4', video_path)[0] + '_result.mp4'
# ------------------------------------------------------------------------

video_fps = 25
capture = cv2.VideoCapture(video_path)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
total_frame = capture.get(7)
print('total_frame:', total_frame)
ref, frame = capture.read()
if not ref:
    raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")
count_id = 0

max_age = 100
min_hits = 1
iou_threshold = 0.3
trackers = []
frame_count = 0

while ref:
    ref, frame = capture.read()
    if not ref:
        break
    count_id += 1
    print('count_id:', count_id)
    # -------------------------------------------------------------------------#
    # 图像进行处理并检测
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = Image.fromarray(np.uint8(frame))
    hatch, cargo = yolo.detect_image(frame)
    frame = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
    frame_count += 1

    # 每一帧输出结果,开始跟踪过程
    # get predicted locations from existing trackers.
    if cargo:
        dets = np.array(cargo)
        # get predicted locations from existing trackers.
        trks = np.zeros((len(trackers), 5))
        to_del = []
        ret = []
        for t, trk in enumerate(trks):
            pos = trackers[t].predict()[0]
            trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
            cv2.rectangle(frame, (int(pos[1])-2, int(pos[0])-2), (int(pos[3])+2, int(pos[2])+2), (0, 255, 0), 2)
            cv2.putText(frame, "Green: kalman_predict_box, slightly enlarge ", (0, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            cv2.putText(frame, "Red: yolov4_predict_box, normal size box ", (0, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

            if np.any(np.isnan(pos)):
                to_del.append(t)
        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
        for t in reversed(to_del):
            trackers.pop(t)
        # associate_detections_to_trackers:详细过程见 https://blog.csdn.net/DeepCBW/article/details/124740092
        matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets, trks, iou_threshold)
        print('unmatched_dets:', unmatched_dets)
        # update matched trackers with assigned detections
        # 只有匹配上才会update,此时,每个track的time_since_update重新归置为0,
        # 没有匹配上轨迹的detections,将会赋予新的轨迹。
        # 新轨迹连续匹配上update时,hit_streak += 1,大于min_hits时才会赋予新的id。
        # 新轨迹未如果没匹配上detection,这个track的time_since_update+=1,重新将hit_streak = 0,直到这个trk.time_since_update>max_age,删除该轨迹。
        for m in matched:
            trackers[m[1]].update(dets[m[0], :])
        # create and initialise new trackers for unmatched detections
        for i in unmatched_dets:
            trk = KalmanBoxTracker(dets[i, :])
            trackers.append(trk)
        i = len(trackers)
        for trk in reversed(trackers):
            d = trk.get_state()[0]
            print('trk.hit_streak:', trk.hit_streak)
            if (trk.time_since_update < 1) and (trk.hit_streak >= min_hits or frame_count <= min_hits):
                ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1))  # +1 as MOT benchmark requires positive
            i -= 1
            # remove dead tracklet
            if trk.time_since_update > max_age:
                trackers.pop(i)
        if len(ret) > 0:
            trajectories = np.concatenate(ret)
        else:
            trajectories = np.empty((0, 5))
        # 根据跟踪结果,绘图
        for i in trajectories:
            track_id = i[4]
            print('track_id:', track_id)
            x1, y1, x2, y2 = int(i[1]), int(i[0]), int(i[3]), int(i[2])
            print('box:', x1, y1, x2, y2)
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 2)
            cv2.putText(frame, "track_id: %d " % track_id, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
        cv2.putText(frame, "frame_id: %d " % count_id, (100, 200), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 255, 0), 2)
        out.write(frame)
        cv2.imshow('frame', frame)
        cv2.waitKey(1)
print("Video Detection Done!")
capture.release()
print("Save processed video to the path:", video_save_path)
out.release()
cv2.destroyAllWindows()
yolo.close_session()

你可能感兴趣的:(目标检测,匈牙利,卡尔曼,目标跟踪,目标检测,yolov4,sort)