使用OpenCV与Python从零搭建多目标跟踪框架

文章目录

  • 一、OpenCV是什么?
  • 二、搭建多目标跟踪框架
    • 1.引入OpenCV库
    • 2.从视频中获取帧
    • 3.用矩形框将目标框出
    • 4.分配目标ID


【博主使用的Python版本:3.9.7】
【博主使用的 OpenCV版本:4.5.0】


本文所使用的资料已上传到百度网盘【https://pan.baidu.com/s/1-OyW8kGbfV58bO4q3GK0tA?pwd=j7u9】,提取码:j7u9。


一、OpenCV是什么?

OpenCV的全称是:Open Source Computer Vision Library, OpenCV是一个基于Apache2.0许可(开源)发行的跨平台计算机视觉和机器学习软件库, 其采用 C/C++ 编写,同时提供了Python、Ruby、MATLAB等语言的接口,实现了图像处理和计算机视觉方面的很多通用算法。其主要关注的是实时应用,同时,OpenCV 的另一个目标是构建一个简单易用的计算机视觉框架,以帮助开发人员更便捷地设计更复杂的计算机视觉相关的应用程序。可以从http://opencv.org获取。

二、搭建多目标跟踪框架

我们要做的事是搭建一个【多目标跟踪】的简单的框架,你可以跟随我的步骤在PyCharm中一步步地把代码填进去,也可以直接复制完整代码,完整代码在本文最底部。

1.引入OpenCV库

在开始之前,我们需要引入的库:

  • numpy :是用Python进行科学计算的基本软件包。
  • math:是主要处理数学相关的运算的常用软件包。
  • cv2:是一个著名的计算机视觉库,用于图像处理分析。
  • object_detection:在本文的资料包里,用于检测目标。

如果你没有以上的库,请自行安装。

import cv2
import numpy as np
from object_detection import ObjectDetection
import math

object_detection.py代码如下:

import cv2
import numpy as np

class ObjectDetection:
    def __init__(self, weights_path="dnn_model/yolov4.weights", cfg_path="dnn_model/yolov4.cfg"):
        print("Loading Object Detection")
        print("Running opencv dnn with YOLOv4")
        self.nmsThreshold = 0.4
        self.confThreshold = 0.5
        self.image_size = 608

        # Load Network
        net = cv2.dnn.readNet(weights_path, cfg_path)

        # Enable GPU CUDA
        net.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA)
        net.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA)
        self.model = cv2.dnn_DetectionModel(net)

        self.classes = []
        self.load_class_names()
        self.colors = np.random.uniform(0, 255, size=(80, 3))

        self.model.setInputParams(size=(self.image_size, self.image_size), scale=1/255)

    def load_class_names(self, classes_path="dnn_model/classes.txt"):

        with open(classes_path, "r") as file_object:
            for class_name in file_object.readlines():
                class_name = class_name.strip()
                self.classes.append(class_name)

        self.colors = np.random.uniform(0, 255, size=(80, 3))
        return self.classes

    def detect(self, frame):
        return self.model.detect(frame, nmsThreshold=self.nmsThreshold, confThreshold=self.confThreshold)

2.从视频中获取帧

在开始之前首先进行第一个测试,以确保后面所做的都是正确的。

cap = cv2.VideoCapture("los_angeles.mp4")      # 加载视频片段

_, frame = cap.read()            # 从视频中获取帧

cv2.imshow("Frame", frame)       # 显示帧
cv2.waitKey(0)                   # 保持窗口打开

运行结果如下:

现在我们成功获取并加载显示出了视频的第一帧,这是一个好的开始。有了第一帧,接下来,继续加载整个视频。

什么是视频?
通俗的来讲,视频就是一个接一个的大量图像。如果你检查相机的规格,例如相机可以以30fps的速度录制,这就意味着相机每秒记录30帧,说明在一秒钟内有30张图像。

所以现在我们将获取帧的步骤放入循环中,在循环内一个接一个地获取帧。注:这里要判断视频是否播放完,即判断是否存在帧,如果不存在帧,则退出循环。

cap = cv2.VideoCapture("los_angeles.mp4")      # 加载视频片段

while True:                         # 获取连续的帧
    ret, frame = cap.read()         # 从视频中获取帧
    if not ret:                     # 是否存在帧
         break                      # 如果不存在帧,则退出

    cv2.imshow("Frame", frame)      # 显示帧

    key = cv2.waitKey(1)            # 1:每帧延迟1ms 
    if key == 27:                   # 注:ESC键
        break

cap.release()                       # 释放视频文件
cv2.destroyAllWindows()             # 关闭所有窗口

运行结果如下(按下ESC键,退出循环):

(这里只能上传5M以内的图,所以上传的动图压缩了,略模糊)

3.用矩形框将目标框出

已经确保了视频每帧都能成功获取,现在调用object_detection.py中目标检测函数,获取每帧中的包含的目标信息,代码如下:

cap = cv2.VideoCapture("los_angeles.mp4")      # 加载视频片段

while True:                                   # 获取连续的帧
    ret, frame = cap.read()                   # 从视频中获取帧
    if not ret:                               # 是否存在帧
         break                                # 如果不存在帧,则退出

    (class_ids, scores, boxes) = od.detect(frame)   # 当前帧中检测目标中的信息
    # class_id:what object is (car / track / person)
    # score: how confident is about the detection and
    # box: bounding box of the location of each object

    for box in boxes:                         # 不区分类别,只画框
        print(box)                            # 打印框,确保提取目标正确

    cv2.imshow("Frame", frame)                # 显示帧

    key = cv2.waitKey(1)                      # 1:每帧延迟1ms 
    if key == 27:                             # 注:ESC键
        break

cap.release()                                 # 释放视频文件
cv2.destroyAllWindows()                       # 关闭所有窗口

运行结果如下(这里只放了部分):

[505 802 133 178]
[376 683 122 118]
[1671  603  159   64]
[727 605  68  87]
[972 610  92  74]
[898 508  61  52]
[826 531  61  69]
[592 457  40  32]
[861 457  39  32]
[1214  880  244  199]
[735 445  36  37]
[1100  424   37   25]
[1835  560   85   91]

以上结果中,每一行代表一个目标的信息,其中前两个数字为目标框的左上角点坐标(x,y),第三个数字为目标框的宽度,第四个数字为目标框的高度。知道这些信息后,我们就可以在每帧中画出矩形框,框出目标。代码如下:

    (class_ids, scores, boxes) = od.detect(frame)   # 当前帧中检测目标中的信息
    # class_id:what object is (car / track / person)
    # score: how confident is about the detection and
    # box: bounding box of the location of each object

    for box in boxes:                                         # 不区分类别,只画框
        (x, y, w, h) = box                                    # 矩形的左上角坐标,及宽度和高度
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)  # 在当前帧中根据box绘制矩形

运行结果:

4.分配目标ID

接下来可以给每个跟踪目标分配ID,确保跟踪的为同一个目标。

 for box in boxes:                    # 不区分类别,只画框
        # print(box)                  # 打印框,确保提取目标正确
        (x, y, w, h) = box            # 矩形的左上角坐标,及宽度和高度
        cx = int((x + x + w) / 2)     # 中心点的x坐标
        cy = int((y + y + h) / 2)     # 中心点的y坐标
        center_points_cur_frame.append((cx, cy))       # 添加新的中心点到数组中

        # print("FRAME N°", count, " ", x, y, w, h)    # 打印每一帧中的框

        # cv2.circle(frame, (cx, cy), 5, (0, 0, 255), -1)             # 在当前帧中以框的中心点为中心画圆,半径为5,红色,用所有颜色填充圆圈
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)    # 在当前帧中根据box绘制矩形

    # only at the beginning we compare previous and current frame
    if count <= 2:
        for pt in center_points_cur_frame:
            # cv2.circle(frame, pt, 5, (0, 0, 255), -1)                # 画出所有中心点
            for pt2 in center_points_prev_frame:
                distance = math.hypot(pt2[0] - pt[0], pt2[1] - pt[1])  # 当前帧与前一帧的目标中心点的距离

                if distance < 20:                          # 距离小于20像素
                    tracking_objects[track_id] = pt        # 当前帧中的目标中心
                    track_id += 1
    else:

        tracking_objects_copy = tracking_objects.copy()                # 建立跟踪目标字典副本
        center_points_cur_frame_copy =  center_points_cur_frame.copy()

        for object_id, pt2 in tracking_objects.copy().items():         # 首先遍历新的数组
            object_exists = False                                      # 首先假设当前帧中不存在目标
            for pt in center_points_cur_frame:                         # 遍历当前帧中的目标
                distance = math.hypot(pt2[0] - pt[0], pt2[1] - pt[1])  # 当前帧与前一帧的目标中心点的距离

                # Update IDs position
                if distance < 20:
                    tracking_objects[object_id] = pt                   # 更新目标位置
                    object_exists = True                               # 目标存在
                    if pt in center_points_cur_frame:
                        center_points_cur_frame.remove(pt)
                        continue                                           # 继续下一帧

            # Remove IDs lost
            if not object_exists:                 # 目标不存在
                tracking_objects.pop(object_id)   # 则移除目标ID

        # Add new IDs found
        for pt in center_points_cur_frame:
            tracking_objects[track_id] = pt
            track_id += 1

    for object_id, pt in tracking_objects.items():
        cv2.circle(frame, pt, 5, (0, 0, 255), -1)      # 画出所有中心点
        # 在帧中显示目标id,文本位置,文本字体类型,字体大小,颜色,粗细(注:需>=0)
        cv2.putText(frame, str(object_id), (pt[0], pt[1] - 7), 0, 1, (0, 0, 255), 0)

运行结果:

object_tracking.py完整代码如下:

import cv2
import numpy as np
from object_detection import ObjectDetection
import math

# Initialize Object Detection
od = ObjectDetection()                         # 加载目标

cap = cv2.VideoCapture("los_angeles.mp4")      # 加载视频片段
frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)     # 通过属性获取帧数

# Initialize count
count = 0                         # 用于计算视频的实际帧数

# center_points = []              # 用于储存所有的中心点

center_points_prev_frame = []     # 空数组用于储存第一帧前的空帧

tracking_objects = {}            # 用于储存跟踪目标
track_id = 0                     # 跟踪目标初始序号

print(cv2.__version__)

while True:                        # 获取连续的帧
    ret, frame = cap.read()        # 从视频中获取帧
    count += 1
    if not ret:                    # 是否存在帧
        break                      # 如果不存在帧,则退出

    # point current frame
    center_points_cur_frame = []   # 用于储存当前帧的目标的中心点

    # Detect objects on frame
    (class_ids, scores, boxes) = od.detect(frame)         # 当前帧中检测目标中的信息
    # class_id:what object is (car / track / person)
    # score: how confident is about the detection and
    # box: bounding box of the location of each object

    for box in boxes:              # 不区分类别,只画框
        # print(box)               # 打印框,确保提取目标正确
        (x, y, w, h) = box           # 矩形的左上角坐标,及宽度和高度
        cx = int((x + x + w) / 2)   # 中心点的x坐标
        cy = int((y + y + h) / 2)   # 中心点的y坐标
        center_points_cur_frame.append((cx, cy))    # 添加新的中心点到数组中

        # print("FRAME N°", count, " ", x, y, w, h)    # 打印每一帧中的框

        # cv2.circle(frame, (cx, cy), 5, (0, 0, 255), -1)             # 在当前帧中以框的中心点为中心画圆,半径为5,红色,用所有颜色填充圆圈
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)    # 在当前帧中根据box绘制矩形

    # only at the beginning we compare previous and current frame
    if count <= 2:
        for pt in center_points_cur_frame:
            # cv2.circle(frame, pt, 5, (0, 0, 255), -1)                # 画出所有中心点
            for pt2 in center_points_prev_frame:
                distance = math.hypot(pt2[0] - pt[0], pt2[1] - pt[1])  # 当前帧与前一帧的目标中心点的距离

                if distance < 20:                          # 距离小于20像素
                    tracking_objects[track_id] = pt        # 当前帧中的目标中心
                    track_id += 1
    else:

        tracking_objects_copy = tracking_objects.copy()                # 建立跟踪目标字典副本
        center_points_cur_frame_copy =  center_points_cur_frame.copy()

        for object_id, pt2 in tracking_objects.copy().items():         # 首先遍历新的数组
            object_exists = False                                      # 首先假设当前帧中不存在目标
            for pt in center_points_cur_frame:                         # 遍历当前帧中的目标
                distance = math.hypot(pt2[0] - pt[0], pt2[1] - pt[1])  # 当前帧与前一帧的目标中心点的距离

                # Update IDs position
                if distance < 20:
                    tracking_objects[object_id] = pt                   # 更新目标位置
                    object_exists = True                               # 目标存在
                    if pt in center_points_cur_frame:
                        center_points_cur_frame.remove(pt)
                        continue                                           # 继续下一帧

            # Remove IDs lost
            if not object_exists:                 # 目标不存在
                tracking_objects.pop(object_id)   # 则移除目标ID

        # Add new IDs found
        for pt in center_points_cur_frame:
            tracking_objects[track_id] = pt
            track_id += 1

    for object_id, pt in tracking_objects.items():
        cv2.circle(frame, pt, 5, (0, 0, 255), -1)      # 画出所有中心点
        # 在帧中显示目标id,文本位置,文本字体类型,字体大小,颜色,粗细(注:需>=0)
        cv2.putText(frame, str(object_id), (pt[0], pt[1] - 7), 0, 1, (0, 0, 255), 0)



    print("Tracking Objects :")
    print(tracking_objects)

    print("CUR FRAME LEFT PTS :")
    print(center_points_cur_frame)

    # print("PREV FRAME :")
    # print(center_points_prev_frame)

    cv2.imshow("Frame", frame)     # 显示帧

    # Make a copy of the points
    center_points_prev_frame = center_points_cur_frame.copy()

    key = cv2.waitKey(0)            # 1:每帧延迟1ms 0:保持当前帧不动
    if key == 27:                   # 注:ESC键
        break

cap.release()                      # 释放视频文件
cv2.destroyAllWindows()            # 关闭所有窗口

print("通过属性获取的视频帧数 :", frames)
print("实际遍历整个视频的帧数 :", count-1)

你可能感兴趣的:(python,opencv,目标跟踪)