yolov8(8.2.10)+deepsort(demo)

只需要训练好yolov8的检测模型然后调用:

results = model.track(frame, persist=True)  # 执行跟踪, persist=True 表示持续跟踪。保持同一个人在多帧画面的id 一

就可以

完整代码:

import cv2
import numpy as np
from ultralytics import YOLO
from collections import defaultdict

# 框的中心点的历史轨迹
track_history = defaultdict(lambda :[])   # 创建一个默认值为列表类型的字典

model = YOLO(r"D:\Project_YC\yl8-8.2.10_rubish\runs\detect\train7\weights\best.pt")

cls_show = ["person", "car", "bus", "truck"]

video_path = r'video/person_rubbish.mp4'
out_video_path = 'out_video2.mp4'

cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    print("打开失败")
    exit()

w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)

fourcc = cv2.VideoWriter_fourcc(*"XVID")

out = cv2.VideoWriter(out_video_path, fourcc, fps, (w, h))
videoWriter = None

rectangel_np = np.array([[710, 200], [1110, 200], [810, 400], [410, 400]], np.int32)
pts = rectangel_np.reshape((-1, 1, 2))


while True:
    ret, frame = cap.read()
    if not ret:
        print("读取帧失败")
        break

    results = model.track(frame, persist=True)  # 执行跟踪, persist=True 表示持续跟踪。保持同一个人在多帧画面的id 一

    # plot()方法将这些追踪结果绘制在原始图像上
    a_frame = results[0].plot()
    # cv2.imshow('11', a_frame)
    # cv2.waitKey(0)

    labels = results[0].names  # 类别名称
    print("类别名称: ", labels)
    boxes = results[0].boxes.xywh.cpu()
    print(results[0].boxes.id, '-------')

    tracks_ids = results[0].boxes.id.int().cpu().tolist() # 把tensor类型的id 转成列表类型
    pre_labels = results[0].boxes.cls.cpu().tolist()
    print(pre_labels)


    for box, track_id in zip(boxes, tracks_ids):
        x, y, w, h = box
        # 获取键 track_id 对应的值,如果不存在,自动生成一个空列表, 同时在track_history里插入一个键名:track_id, 如果 track_id 已经存在于 track_history 中,则直接返回对应的值(一个列表)
        track = track_history[track_id]             # track:   track_id 对应的值,如果不存在,自动生成一个空列表, 这里目前不存在值,所以给键的是一个自动生成的空列表
        track.append((float(x), float(y)))          # 把 (float(x), float(y)) 添加到值里

        # 如果track里的值超过50个,则把track的第一个值剔除。既当(x,y)这个中心点元组存储超过50个就最前面那个剔除。保证列表始终保持最多 50 个坐标点,这是为了限制跟踪历史的长度,防止内存占用过多或数据冗长
        if len(track) > 50:
            track.pop(0)

        # 一个id 的累计轨迹点
        points = np.hstack(track).astype(np.int32).reshape(-1, 1, 2)
        cv2.polylines(a_frame, [points], isClosed=False, color=(255, 0, 255), thickness=3)

    out.write(a_frame)
    cv2.imshow("88", a_frame)
    cv2.waitKey(1)


cap.release()
cv2.destroyWindow()












你可能感兴趣的:(YOLO)